diff --git a/src/execute.rs b/src/execute.rs index 9cca77dc0..da4a3e7bf 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -188,7 +188,7 @@ pub async fn run(cli: Cli) -> Result<(), Box> { } /// helper function -fn verify_proof_circuit_kzg< +pub fn verify_proof_circuit_kzg< 'params, Strategy: VerificationStrategy<'params, KZGCommitmentScheme, VerifierGWC<'params, Bn256>>, >( diff --git a/src/python.rs b/src/python.rs index f786834d3..08cddb887 100644 --- a/src/python.rs +++ b/src/python.rs @@ -1,13 +1,21 @@ use crate::circuit::CheckMode; use crate::commands::{RunArgs, StrategyType, TranscriptType}; -use crate::execute::{create_proof_circuit_kzg, load_params_cmd}; -use crate::graph::{quantize_float, Mode, Model, ModelCircuit, VarVisibility}; -use crate::pfsys::{gen_srs as ezkl_gen_srs, create_keys, prepare_data, save_params, save_vk}; +use crate::execute::{create_proof_circuit_kzg, load_params_cmd, verify_proof_circuit_kzg}; +use crate::graph::{quantize_float, Mode, Model, ModelCircuit, ModelParams, VarVisibility}; +use crate::pfsys::{ + gen_srs as ezkl_gen_srs, + create_keys, + prepare_data, + save_params, + save_vk, + load_vk, + Snark +}; use halo2_proofs::poly::kzg::{ commitment::KZGCommitmentScheme, strategy::{AccumulatorStrategy, SingleStrategy as KZGSingleStrategy}, }; -use halo2_proofs::dev::MockProver; +use halo2_proofs::{dev::MockProver, poly::commitment::ParamsProver}; use halo2curves::bn256::{Bn256, Fr}; use log::trace; use pyo3::exceptions::{PyIOError, PyRuntimeError}; @@ -228,6 +236,7 @@ fn mock( vk_path, proof_path, params_path, + circuit_params_path, transcript, strategy, py_run_args = None @@ -238,6 +247,7 @@ fn prove( vk_path: PathBuf, proof_path: PathBuf, params_path: PathBuf, + circuit_params_path: PathBuf, transcript: TranscriptType, strategy: StrategyType, py_run_args: Option @@ -264,6 +274,8 @@ fn prove( let proving_key = create_keys::, Fr, ModelCircuit>(&circuit, ¶ms) .map_err(|_| PyRuntimeError::new_err("Failed to create proving key"))?; + let circuit_params = circuit.params.clone(); + let snark = match strategy { StrategyType::Single => { let strategy = KZGSingleStrategy::new(¶ms); @@ -297,14 +309,51 @@ fn prove( } }; - match snark?.save(&proof_path) { - Ok(_) => { - match save_vk::>(&vk_path, proving_key.get_vk()) { - Ok(_) => Ok(true), - Err(_) => Err(PyIOError::new_err("Failed to save vk to vk_path")) - } - } - Err(_) => Err(PyIOError::new_err("Failed to save to proof path")) + // save the snark proof + snark?.save(&proof_path) + .map_err(|_| PyIOError::new_err("Failed to save proof to proof path"))?; + + // save the verifier key + save_vk::>(&vk_path, proving_key.get_vk()) + .map_err(|_| PyIOError::new_err("Failed to save verifier key to vk_path"))?; + + // save the circuit + circuit_params.save(&circuit_params_path); + + Ok(true) +} + +/// verifies a given proof +#[pyfunction(signature = ( + proof_path, + circuit_params_path, + vk_path, + params_path, + transcript, + py_run_args = None +))] +fn verify( + proof_path: PathBuf, + circuit_params_path: PathBuf, + vk_path: PathBuf, + params_path: PathBuf, + transcript: TranscriptType, + py_run_args: Option, +) -> Result { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + let logrows = run_args.logrows; + let params = load_params_cmd(params_path, logrows) + .map_err(|_| PyIOError::new_err("Failed to load params"))?; + let proof = Snark::load::>(&proof_path, None, None) + .map_err(|_| PyIOError::new_err("Failed to load proof"))?; + let model_circuit_params = ModelParams::load(&circuit_params_path); + let strategy = KZGSingleStrategy::new(params.verifier_params()); + let vk = load_vk::, Fr, ModelCircuit>(vk_path, model_circuit_params) + .map_err(|_| PyIOError::new_err("Failed to load verifier key"))?; + let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, transcript, strategy); + match result { + Ok(_) => Ok(true), + Err(_) => Ok(false), } } @@ -313,7 +362,6 @@ fn prove( // TODO: CreateEVMVerifierAggr // TODO: DeployVerifierEVM // TODO: SendProofEVM -// TODO: Verify // TODO: VerifyAggr // TODO: VerifyEVM // TODO: PrintProofHex @@ -328,6 +376,7 @@ fn ezkl_lib(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(forward, m)?)?; m.add_function(wrap_pyfunction!(mock, m)?)?; m.add_function(wrap_pyfunction!(prove, m)?)?; + m.add_function(wrap_pyfunction!(verify, m)?)?; Ok(()) } diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py index 351add057..e7ab1aa87 100644 --- a/tests/python/binding_tests.py +++ b/tests/python/binding_tests.py @@ -128,6 +128,7 @@ def test_prove(): vk_path = os.path.join(folder_path, 'test.vk') proof_path = os.path.join(folder_path, 'test.pf') + circuit_params_path = os.path.join(folder_path, 'circuit.params') res = ezkl_lib.prove( data_path, @@ -135,9 +136,29 @@ def test_prove(): vk_path, proof_path, params_path, + circuit_params_path, "poseidon", "single", ) assert res == True assert os.path.isfile(vk_path) assert os.path.isfile(proof_path) + + +def test_verify(): + """ + Test for verify + """ + + vk_path = os.path.join(folder_path, 'test.vk') + proof_path = os.path.join(folder_path, 'test.pf') + circuit_params_path = os.path.join(folder_path, 'circuit.params') + + res = ezkl_lib.verify( + proof_path, + circuit_params_path, + vk_path, + params_path, + "poseidon", + ) + assert res == True