Skip to content

Commit

Permalink
fix: re-enable the "compress_selectors" option
Browse files Browse the repository at this point in the history
  • Loading branch information
guorong009 committed May 6, 2024
1 parent 33b6dbd commit bbc773b
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 38 deletions.
14 changes: 11 additions & 3 deletions halo2_frontend/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub use table_layouter::{SimpleTableLayouter, TableLayouter};
pub fn compile_circuit<F: Field, ConcreteCircuit: Circuit<F>>(
k: u32,
circuit: &ConcreteCircuit,
compress_selectors: bool,
) -> Result<
(
CompiledCircuit<F>,
Expand All @@ -49,17 +50,17 @@ pub fn compile_circuit<F: Field, ConcreteCircuit: Circuit<F>>(
Error,
> {
let n = 2usize.pow(k);

let mut cs = ConstraintSystem::default();

#[cfg(feature = "circuit-params")]
let config = ConcreteCircuit::configure_with_params(&mut cs, circuit.params());
#[cfg(not(feature = "circuit-params"))]
let config = ConcreteCircuit::configure(&mut cs);
let cs = cs;

if n < cs.minimum_rows() {
return Err(Error::not_enough_rows_available(k));
}

let mut assembly = plonk::keygen::Assembly {
k,
fixed: vec![vec![F::ZERO.into(); n]; cs.num_fixed_columns],
Expand All @@ -78,7 +79,14 @@ pub fn compile_circuit<F: Field, ConcreteCircuit: Circuit<F>>(
)?;

let mut fixed = batch_invert_assigned(assembly.fixed);
let (cs, selector_polys) = cs.compress_selectors(assembly.selectors);
let (cs, selector_polys) = if compress_selectors {
cs.compress_selectors(assembly.selectors)
} else {
// After this, the ConstraintSystem should not have any selectors: `verify` does not need them, and `keygen_pk` regenerates `cs` from scratch anyways.
let selectors = std::mem::take(&mut assembly.selectors);
cs.directly_convert_selectors_to_fixed(selectors)
};

fixed.extend(selector_polys);

// sort the "copies" for deterministic ordering
Expand Down
33 changes: 33 additions & 0 deletions halo2_frontend/src/plonk/circuit/constraint_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,39 @@ impl<F: Field> ConstraintSystem<F> {
(self, polys)
}

/// Does not combine selectors and directly replaces them everywhere with fixed columns.
pub fn directly_convert_selectors_to_fixed(
mut self,
selectors: Vec<Vec<bool>>,
) -> (Self, Vec<Vec<F>>) {
// The number of provided selector assignments must be the number we
// counted for this constraint system.
assert_eq!(selectors.len(), self.num_selectors);

let (polys, selector_replacements): (Vec<_>, Vec<_>) = selectors
.into_iter()
.map(|selector| {
let poly = selector
.iter()
.map(|b| if *b { F::ONE } else { F::ZERO })
.collect::<Vec<_>>();
let column = self.fixed_column();
let rotation = Rotation::cur();
let expr = Expression::Fixed(FixedQuery {
index: Some(self.query_fixed_index(column, rotation)),
column_index: column.index,
rotation,
});
(poly, expr)
})
.unzip();

self.replace_selectors_with_fixed(&selector_replacements);
self.num_selectors = 0;

(self, polys)
}

fn replace_selectors_with_fixed(&mut self, selector_replacements: &[Expression<F>]) {
fn replace_selectors<F: Field>(
expr: &mut Expression<F>,
Expand Down
16 changes: 11 additions & 5 deletions halo2_proofs/examples/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ use ff::Field;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{
create_proof, keygen_pk, keygen_vk, pk_read, verify_proof, Advice, Circuit, Column,
ConstraintSystem, ErrorFront, Fixed, Instance,
create_proof, keygen_pk, keygen_vk_custom, pk_read, verify_proof, Advice, Circuit, Column, ConstraintSystem, ErrorFront, Fixed, Instance
},
poly::{
kzg::{
Expand Down Expand Up @@ -141,7 +140,8 @@ fn main() {
let k = 4;
let circuit = StandardPlonk(Fr::random(OsRng));
let params = ParamsKZG::<Bn256>::setup(k, OsRng);
let vk = keygen_vk(&params, &circuit).expect("vk should not fail");
let compress_selectors = true;
let vk = keygen_vk_custom(&params, &circuit, compress_selectors).expect("vk should not fail");
let pk = keygen_pk(&params, vk, &circuit).expect("pk should not fail");

let f = File::create("serialization-test.pk").unwrap();
Expand All @@ -152,8 +152,14 @@ fn main() {
let f = File::open("serialization-test.pk").unwrap();
let mut reader = BufReader::new(f);
#[allow(clippy::unit_arg)]
let pk = pk_read::<G1Affine, _, StandardPlonk>(&mut reader, SerdeFormat::RawBytes, k, &circuit)
.unwrap();
let pk = pk_read::<G1Affine, _, StandardPlonk>(
&mut reader,
SerdeFormat::RawBytes,
k,
&circuit,
compress_selectors,
)
.unwrap();

std::fs::remove_file("serialization-test.pk").unwrap();

Expand Down
10 changes: 6 additions & 4 deletions halo2_proofs/src/plonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ mod verifier {
}

use halo2_frontend::circuit::compile_circuit;
pub use keygen::{keygen_pk, keygen_vk};
pub use keygen::{keygen_pk, keygen_vk_custom, keygen_vk};

pub use prover::{create_proof, create_proof_with_engine};
pub use prover::{create_proof, create_proof_custom_with_engine, create_proof_with_engine};
pub use verifier::verify_proof;

pub use error::Error;
Expand Down Expand Up @@ -46,11 +46,12 @@ pub fn vk_read<C: SerdeCurveAffine, R: io::Read, ConcreteCircuit: Circuit<C::Sca
format: SerdeFormat,
k: u32,
circuit: &ConcreteCircuit,
compress_selectors: bool,
) -> io::Result<VerifyingKey<C>>
where
C::Scalar: SerdePrimeField + FromUniformBytes<64>,
{
let (_, _, cs) = compile_circuit(k, circuit)
let (_, _, cs) = compile_circuit(k, circuit, compress_selectors)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
let cs_mid: ConstraintSystemMid<_> = cs.into();
VerifyingKey::read(reader, format, cs_mid.into())
Expand All @@ -73,11 +74,12 @@ pub fn pk_read<C: SerdeCurveAffine, R: io::Read, ConcreteCircuit: Circuit<C::Sca
format: SerdeFormat,
k: u32,
circuit: &ConcreteCircuit,
compress_selectors: bool,
) -> io::Result<ProvingKey<C>>
where
C::Scalar: SerdePrimeField + FromUniformBytes<64>,
{
let (_, _, cs) = compile_circuit(k, circuit)
let (_, _, cs) = compile_circuit(k, circuit, compress_selectors)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
let cs_mid: ConstraintSystemMid<_> = cs.into();
ProvingKey::read(reader, format, cs_mid.into())
Expand Down
48 changes: 46 additions & 2 deletions halo2_proofs/src/plonk/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,29 @@ where
ConcreteCircuit: Circuit<C::Scalar>,
C::Scalar: FromUniformBytes<64>,
{
let (compiled_circuit, _, _) = compile_circuit(params.k(), circuit)?;
keygen_vk_custom(params, circuit, true)
}

/// Generate a `VerifyingKey` from an instance of `Circuit`.
///
/// The selector compression optimization is turned on only if `compress_selectors` is `true`.
///
/// **NOTE**: This `keygen_vk_custom` MUST share the same `compress_selectors` with
/// `ProvingKey` generation process.
/// Otherwise, the user could get unmatching pk/vk pair.
/// Hence, it is HIGHLY recommended to pair this util with `keygen_pk_custom`.
pub fn keygen_vk_custom<'params, C, P, ConcreteCircuit>(
params: &P,
circuit: &ConcreteCircuit,
compress_selectors: bool,
) -> Result<VerifyingKey<C>, Error>
where
C: CurveAffine,
P: Params<'params, C>,
ConcreteCircuit: Circuit<C::Scalar>,
C::Scalar: FromUniformBytes<64>,
{
let (compiled_circuit, _, _) = compile_circuit(params.k(), circuit, compress_selectors)?;
Ok(backend_keygen_vk(params, &compiled_circuit)?)
}

Expand All @@ -44,6 +66,28 @@ where
P: Params<'params, C>,
ConcreteCircuit: Circuit<C::Scalar>,
{
let (compiled_circuit, _, _) = compile_circuit(params.k(), circuit)?;
keygen_pk_custom(params, vk, circuit, true)
}

/// Generate a `ProvingKey` from an instance of `Circuit`.
///
/// The selector compression optimization is turned on only if `compress_selectors` is `true`.
///
/// **NOTE**: This `keygen_pk_custom` MUST share the same `compress_selectors` with
/// `VerifyingKey` generation process.
/// Otherwise, the user could get unmatching pk/vk pair.
/// Hence, it is HIGHLY recommended to pair this util with `keygen_vk_custom`.
pub fn keygen_pk_custom<'params, C, P, ConcreteCircuit>(
params: &P,
vk: VerifyingKey<C>,
circuit: &ConcreteCircuit,
compress_selectors: bool,
) -> Result<ProvingKey<C>, Error>
where
C: CurveAffine,
P: Params<'params, C>,
ConcreteCircuit: Circuit<C::Scalar>,
{
let (compiled_circuit, _, _) = compile_circuit(params.k(), circuit, compress_selectors)?;
Ok(backend_keygen_pk(params, vk, &compiled_circuit)?)
}
80 changes: 57 additions & 23 deletions halo2_proofs/src/plonk/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,9 @@ pub fn create_proof_with_engine<
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
if circuits.len() != instances.len() {
return Err(Error::Backend(ErrorBack::InvalidInstances));
}

let (_, config, cs) = compile_circuit::<_, ConcreteCircuit>(params.k(), &circuits[0])?;
let mut witness_calcs: Vec<_> = circuits
.iter()
.enumerate()
.map(|(i, circuit)| WitnessCalculator::new(params.k(), circuit, &config, &cs, instances[i]))
.collect();
let mut prover = Prover::<Scheme, P, _, _, _, _>::new_with_engine(
engine, params, pk, instances, rng, transcript,
)?;
let mut challenges = HashMap::new();
let phases = prover.phases().to_vec();
for phase in phases.iter() {
let mut witnesses = Vec::with_capacity(circuits.len());
for witness_calc in witness_calcs.iter_mut() {
witnesses.push(witness_calc.calc(*phase, &challenges)?);
}
challenges = prover.commit_phase(*phase, witnesses).unwrap();
}
Ok(prover.create_proof()?)
create_proof_custom_with_engine::<Scheme, P, E, R, T, ConcreteCircuit, M>(
engine, params, pk, true, circuits, instances, rng, transcript,
)
}

/// This creates a proof for the provided `circuit` when given the public
Expand Down Expand Up @@ -91,6 +71,60 @@ where
)
}

/// This creates a proof for the provided `circuit` when given the public
/// parameters `params` and the proving key [`ProvingKey`] that was
/// generated previously for the same circuit. The provided `instances`
/// are zero-padded internally.
/// In addition, this needs the `compress_selectors` field.
#[allow(clippy::too_many_arguments)]
pub fn create_proof_custom_with_engine<
'params,
Scheme: CommitmentScheme,
P: commitment::Prover<'params, Scheme>,
E: EncodedChallenge<Scheme::Curve>,
R: RngCore,
T: TranscriptWrite<Scheme::Curve, E>,
ConcreteCircuit: Circuit<Scheme::Scalar>,
M: MsmAccel<Scheme::Curve>,
>(
engine: PlonkEngine<Scheme::Curve, M>,
params: &'params Scheme::ParamsProver,
pk: &ProvingKey<Scheme::Curve>,
compress_selectors: bool,
circuits: &[ConcreteCircuit],
instances: &[&[&[Scheme::Scalar]]],
rng: R,
transcript: &mut T,
) -> Result<(), Error>
where
Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64>,
{
if circuits.len() != instances.len() {
return Err(Error::Backend(ErrorBack::InvalidInstances));
}

let (_, config, cs) =
compile_circuit::<_, ConcreteCircuit>(params.k(), &circuits[0], compress_selectors)?;
let mut witness_calcs: Vec<_> = circuits
.iter()
.enumerate()
.map(|(i, circuit)| WitnessCalculator::new(params.k(), circuit, &config, &cs, instances[i]))
.collect();
let mut prover = Prover::<Scheme, P, _, _, _, _>::new_with_engine(
engine, params, pk, instances, rng, transcript,
)?;
let mut challenges = HashMap::new();
let phases = prover.phases().to_vec();
for phase in phases.iter() {
let mut witnesses = Vec::with_capacity(circuits.len());
for witness_calc in witness_calcs.iter_mut() {
witnesses.push(witness_calc.calc(*phase, &challenges)?);
}
challenges = prover.commit_phase(*phase, witnesses).unwrap();
}
Ok(prover.create_proof()?)
}

#[test]
fn test_create_proof() {
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion halo2_proofs/tests/frontend_backend_split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ fn test_mycircuit_full_split() {
.build();
let k = K;
let circuit: MyCircuit<Fr, WIDTH_FACTOR> = MyCircuit::new(k, 42);
let (compiled_circuit, config, cs) = compile_circuit(k, &circuit).unwrap();
let (compiled_circuit, config, cs) = compile_circuit(k, &circuit, false).unwrap();

// Setup
let mut rng = BlockRng::new(OneNg {});
Expand Down

0 comments on commit bbc773b

Please sign in to comment.