diff --git a/src/commands.rs b/src/commands.rs index 3a717ab70..3310be709 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -336,6 +336,9 @@ pub enum Commands { /// max logrows to use for calibration, 26 is the max public SRS size #[arg(long)] max_logrows: Option, + // whether to fix the div_rebasing value truthiness during calibration. this changes how we rebase + #[arg(long)] + div_rebasing: Option, }, /// Generates a dummy SRS diff --git a/src/execute.rs b/src/execute.rs index 11f1903c2..38950ff2f 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -178,6 +178,7 @@ pub async fn run(command: Commands) -> Result> { scales, scale_rebase_multiplier, max_logrows, + div_rebasing, } => calibrate( model, data, @@ -186,6 +187,7 @@ pub async fn run(command: Commands) -> Result> { lookup_safety_margin, scales, scale_rebase_multiplier, + div_rebasing, max_logrows, ) .map(|e| serde_json::to_string(&e).unwrap()), @@ -782,6 +784,7 @@ pub(crate) fn calibrate( lookup_safety_margin: i128, scales: Option>, scale_rebase_multiplier: Vec, + div_rebasing: Option, max_logrows: Option, ) -> Result> { use std::collections::HashMap; @@ -825,6 +828,12 @@ pub(crate) fn calibrate( } }; + let div_rebasing = if let Some(div_rebasing) = div_rebasing { + vec![div_rebasing] + } else { + vec![true, false] + }; + let mut found_params: Vec = vec![]; // 2 x 2 grid @@ -862,15 +871,21 @@ pub(crate) fn calibrate( .map(|(a, b)| (*a, *b)) .collect::>(); + let range_grid = range_grid + .iter() + .cartesian_product(div_rebasing.iter()) + .map(|(a, b)| (*a, *b)) + .collect::>(); + let mut forward_pass_res = HashMap::new(); let pb = init_bar(range_grid.len() as u64); pb.set_message("calibrating..."); - for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid { + for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid { pb.set_message(format!( - "input scale: {}, param scale: {}, scale rebase multiplier: {}", - input_scale, param_scale, scale_rebase_multiplier + "input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}", + input_scale, param_scale, scale_rebase_multiplier, div_rebasing )); #[cfg(unix)] @@ -890,6 +905,7 @@ pub(crate) fn calibrate( input_scale, param_scale, scale_rebase_multiplier, + div_rebasing, ..settings.run_args.clone() }; @@ -964,6 +980,7 @@ pub(crate) fn calibrate( let found_run_args = RunArgs { input_scale: new_settings.run_args.input_scale, param_scale: new_settings.run_args.param_scale, + div_rebasing: new_settings.run_args.div_rebasing, lookup_range: new_settings.run_args.lookup_range, logrows: new_settings.run_args.logrows, scale_rebase_multiplier: new_settings.run_args.scale_rebase_multiplier, diff --git a/src/python.rs b/src/python.rs index 474fc4d85..fb64cded7 100644 --- a/src/python.rs +++ b/src/python.rs @@ -521,6 +521,7 @@ fn gen_settings( scales = None, scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(), max_logrows = None, + div_rebasing = None, ))] fn calibrate_settings( data: PathBuf, @@ -531,6 +532,7 @@ fn calibrate_settings( scales: Option>, scale_rebase_multiplier: Vec, max_logrows: Option, + div_rebasing: Option, ) -> Result { crate::execute::calibrate( model, @@ -540,6 +542,7 @@ fn calibrate_settings( lookup_safety_margin, scales, scale_rebase_multiplier, + div_rebasing, max_logrows, ) .map_err(|e| { diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 24acb8909..f9b2e8342 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -836,10 +836,10 @@ mod native_tests { let test_dir = TempDir::new(test).unwrap(); env_logger::init(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, Some(vec![0,1]), true, "single"); + kzg_prove_and_verify(path, test.to_string(), "safe", "private", "private", "public", 1, None, true, "single"); #[cfg(not(feature = "icicle"))] run_js_tests(path, test.to_string(), "testWasm"); - test_dir.close().unwrap(); + // test_dir.close().unwrap(); } #(#[test_case(WASM_TESTS[N])])* @@ -849,7 +849,7 @@ mod native_tests { let test_dir = TempDir::new(test).unwrap(); env_logger::init(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, Some(vec![0,1]), true, "single"); + kzg_prove_and_verify(path, test.to_string(), "safe", "private", "fixed", "public", 1, None, true, "single"); #[cfg(not(feature = "icicle"))] run_js_tests(path, test.to_string(), "testWasm"); test_dir.close().unwrap(); @@ -865,7 +865,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, Some(vec![0,6]), false, "single"); + kzg_prove_and_verify(path, test.to_string(), "unsafe", "private", "fixed", "public", 1, None, false, "single"); test_dir.close().unwrap(); } @@ -875,7 +875,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", Some(vec![0,6])); + mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None); test_dir.close().unwrap(); } });