Skip to content

Commit

Permalink
Better handling of EC point at infinity (#44)
Browse files Browse the repository at this point in the history
* feat: allow `msm_par` to return identity point

* feat: handle point at infinity

`multi_scalar_multiply` and `multi_exp_par` now handle point at infinity
completely

Add docs for `ec_add_unequal, ec_sub_unequal, ec_double_and_add_unequal`
to specify point at infinity leads to undefined behavior
  • Loading branch information
jonathanpwang committed May 23, 2023
1 parent 8e9032c commit 2c276b4
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 13 deletions.
62 changes: 56 additions & 6 deletions halo2-ecc/src/ecc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ impl<F: PrimeField, FC: FieldChip<F>> From<ComparableEcPoint<F, FC>>
/// If `is_strict = true`, then this function constrains that `P.x != Q.x`.
/// If you are calling this with `is_strict = false`, you must ensure that `P.x != Q.x` by some external logic (such
/// as a mathematical theorem).
///
/// # Assumptions
/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise)
pub fn ec_add_unequal<F: PrimeField, FC: FieldChip<F>>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -208,6 +211,9 @@ fn check_points_are_unequal<F: PrimeField, FC: FieldChip<F>>(
/// If `is_strict = true`, then this function constrains that `P.x != Q.x`.
/// If you are calling this with `is_strict = false`, you must ensure that `P.x != Q.x` by some external logic (such
/// as a mathematical theorem).
///
/// # Assumptions
/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise)
pub fn ec_sub_unequal<F: PrimeField, FC: FieldChip<F>>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -242,6 +248,31 @@ pub fn ec_sub_unequal<F: PrimeField, FC: FieldChip<F>>(
EcPoint::new(x_3, y_3)
}

/// Constrains `P != -Q` but allows `P == Q`, in which case output is (0,0).
/// For Weierstrass curves only.
pub fn ec_sub_strict<F: PrimeField, FC: FieldChip<F>>(
chip: &FC,
ctx: &mut Context<F>,
P: impl Into<EcPoint<F, FC::FieldPoint>>,
Q: impl Into<EcPoint<F, FC::FieldPoint>>,
) -> EcPoint<F, FC::FieldPoint>
where
FC: Selectable<F, FC::FieldPoint>,
{
let P = P.into();
let Q = Q.into();
// Compute curr_point - start_point, allowing for output to be identity point
let x_is_eq = chip.is_equal(ctx, P.x(), Q.x());
let y_is_eq = chip.is_equal(ctx, P.y(), Q.y());
let is_identity = chip.gate().and(ctx, x_is_eq, y_is_eq);
// we ONLY allow x_is_eq = true if y_is_eq is also true; this constrains P != -Q
ctx.constrain_equal(&x_is_eq, &is_identity);

let out = ec_sub_unequal(chip, ctx, P, Q, false);
let zero = chip.load_constant(ctx, FC::FieldType::zero());
ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity)
}

// Implements:
// computing 2P on elliptic curve E for P = (x, y)
// formula from https://crypto.stanford.edu/pbc/notes/elliptic/explicit.html
Expand All @@ -254,6 +285,9 @@ pub fn ec_sub_unequal<F: PrimeField, FC: FieldChip<F>>(
// we precompute lambda and constrain (2y) * lambda = 3 x^2 (mod p)
// then we compute x_3 = lambda^2 - 2 x (mod p)
// y_3 = lambda (x - x_3) - y (mod p)
/// # Assumptions
/// * `P.y != 0`
/// * `P` is not the point at infinity
pub fn ec_double<F: PrimeField, FC: FieldChip<F>>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -290,6 +324,9 @@ pub fn ec_double<F: PrimeField, FC: FieldChip<F>>(
// lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0)
// x_res = lambda_1^2 - x_0 - x_2
// y_res = lambda_1 * (x_res - x_0) - y_0
///
/// # Assumptions
/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise)
pub fn ec_double_and_add_unequal<F: PrimeField, FC: FieldChip<F>>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -426,14 +463,16 @@ where
StrictEcPoint::new(x, y)
}

// computes [scalar] * P on y^2 = x^3 + b
// - `scalar` is represented as a reference array of `AssignedCell`s
// computes [scalar] * P on short Weierstrass curve `y^2 = x^3 + b`
// - `scalar` is represented as a reference array of `AssignedValue`s
// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
// assumes:
// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits)
// - `max_bits <= modulus::<F>.bits()`
// * P has order given by the scalar field modulus
/// # Assumptions
/// * `P` is not the point at infinity
/// * `scalar` is less than the order of `P`
/// * `scalar_i < 2^{max_bits} for all i`
/// * `max_bits <= modulus::<F>.bits()`, and equality only allowed when the order of `P` equals the modulus of `F`
pub fn scalar_multiply<F: PrimeField, FC>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -578,6 +617,16 @@ where
// Input:
// - `scalars` is vector of same length as `P`
// - each `scalar` in `scalars` satisfies same assumptions as in `scalar_multiply` above

/// # Assumptions
/// * `points.len() == scalars.len()`
/// * `scalars[i].len() == scalars[j].len()` for all `i, j`
/// * `scalars[i]` is less than the order of `P`
/// * `scalars[i][j] < 2^{max_bits} for all j`
/// * `max_bits <= modulus::<F>.bits()`, and equality only allowed when the order of `P` equals the modulus of `F`
/// * `points` are all on the curve or the point at infinity
/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point)
/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point
pub fn multi_scalar_multiply<F: PrimeField, FC, C>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -688,7 +737,7 @@ where
curr_point = ec_add_unequal(chip, ctx, curr_point, add_point, true);
}
}
ec_sub_unequal(chip, ctx, curr_point, start_point, true)
ec_sub_strict(chip, ctx, curr_point, start_point)
}

pub fn get_naf(mut exp: Vec<u64>) -> Vec<i8> {
Expand Down Expand Up @@ -965,6 +1014,7 @@ where
}

// default for most purposes
/// See [`pippenger::multi_exp_par`] for more details.
pub fn variable_base_msm<C>(
&self,
thread_pool: &mut GateThreadBuilder<F>,
Expand Down
20 changes: 13 additions & 7 deletions halo2-ecc/src/ecc/pippenger.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use super::{
ec_add_unequal, ec_double, ec_select, ec_sub_unequal, into_strict_point, load_random_point,
strict_ec_select_from_bits, EcPoint, StrictEcPoint,
strict_ec_select_from_bits, EcPoint,
};
use crate::{
ecc::ec_sub_strict,
fields::{FieldChip, PrimeField, Selectable},
};
use crate::fields::{FieldChip, PrimeField, Selectable};
use halo2_base::{
gates::{builder::GateThreadBuilder, GateInstructions},
utils::CurveAffineExt,
Expand Down Expand Up @@ -64,6 +67,7 @@ where
}
*/

/* Left as reference; should always use msm_par
// Given points[i] and bool_scalars[j][i],
// compute G'[j] = sum_{i=0..points.len()} points[i] * bool_scalars[j][i]
// output is [ G'[j] + rand_point ]_{j=0..bool_scalars.len()}, rand_point
Expand Down Expand Up @@ -200,15 +204,17 @@ where
ec_sub_unequal(chip, ctx, sum, any_sum, true)
}
*/

/// Multi-thread witness generation for multi-scalar multiplication.
/// Should give exact same circuit as `multi_exp`.
///
/// Currently does not support if the final answer is actually the point at infinity (meaning constraints will fail in that case)
///
/// # Assumptions
/// * `points.len() == scalars.len()`
/// * `scalars[i].len() == scalars[j].len()` for all `i, j`
/// * `points` are all on the curve or the point at infinity
/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point)
/// * `2^max_scalar_bits != +-1 mod modulus::<F>()` where `max_scalar_bits = max_scalar_bits_per_cell * scalars[0].len()`
/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point
pub fn multi_exp_par<F: PrimeField, FC, C>(
chip: &FC,
// these are the "threads" within a single Phase
Expand All @@ -226,7 +232,7 @@ where
{
// let (points, bool_scalars) = decompose::<F, _>(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix);

debug_assert_eq!(points.len(), scalars.len());
assert_eq!(points.len(), scalars.len());
let scalar_bits = max_scalar_bits_per_cell * scalars[0].len();
// bool_scalars: 2d array `scalar_bits` by `points.len()`
let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits];
Expand Down Expand Up @@ -348,5 +354,5 @@ where
// assume 2^scalar_bits != +-1 mod modulus::<F>()
any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false);

ec_sub_unequal(chip, ctx, sum, any_sum, true)
ec_sub_strict(chip, ctx, sum, any_sum)
}

0 comments on commit 2c276b4

Please sign in to comment.