Skip to content

Commit

Permalink
make ecc tests generic (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
Leo authored May 10, 2023
1 parent cddd707 commit f16fa1e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 34 deletions.
104 changes: 72 additions & 32 deletions src/gadgets/ecc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,11 +755,8 @@ mod tests {
{shape_cs::ShapeCS, solver::SatisfyingAssignment},
};
use ff::{Field, PrimeFieldBits};
use pasta_curves::{arithmetic::CurveAffine, group::Curve, EpAffine};
use pasta_curves::{arithmetic::CurveAffine, group::Curve, pallas, vesta};
use rand::rngs::OsRng;
use std::ops::Mul;
type G1 = pasta_curves::pallas::Point;
type G2 = pasta_curves::vesta::Point;

#[derive(Debug, Clone)]
pub struct Point<G>
Expand All @@ -783,7 +780,7 @@ mod tests {
pub fn random_vartime() -> Self {
loop {
let x = G::Base::random(&mut OsRng);
let y = (x * x * x + G::Base::from(5)).sqrt();
let y = (x.square() * x + G::get_curve_params().1).sqrt();
if y.is_some().unwrap_u8() == 1 {
return Self {
x,
Expand Down Expand Up @@ -897,52 +894,61 @@ mod tests {

#[test]
fn test_ecc_ops() {
test_ecc_ops_with::<pallas::Affine, pallas::Point>();
test_ecc_ops_with::<vesta::Affine, vesta::Point>();
}

fn test_ecc_ops_with<C, G>()
where
C: CurveAffine<Base = G::Base, ScalarExt = G::Scalar>,
G: Group,
{
// perform some curve arithmetic
let a = Point::<G1>::random_vartime();
let b = Point::<G1>::random_vartime();
let a = Point::<G>::random_vartime();
let b = Point::<G>::random_vartime();
let c = a.add(&b);
let d = a.double();
let s = <G1 as Group>::Scalar::random(&mut OsRng);
let s = <G as Group>::Scalar::random(&mut OsRng);
let e = a.scalar_mul(&s);

// perform the same computation by translating to pasta_curve types
let a_pasta = EpAffine::from_xy(
pasta_curves::Fp::from_repr(a.x.to_repr()).unwrap(),
pasta_curves::Fp::from_repr(a.y.to_repr()).unwrap(),
// perform the same computation by translating to curve types
let a_curve = C::from_xy(
C::Base::from_repr(a.x.to_repr()).unwrap(),
C::Base::from_repr(a.y.to_repr()).unwrap(),
)
.unwrap();
let b_pasta = EpAffine::from_xy(
pasta_curves::Fp::from_repr(b.x.to_repr()).unwrap(),
pasta_curves::Fp::from_repr(b.y.to_repr()).unwrap(),
let b_curve = C::from_xy(
C::Base::from_repr(b.x.to_repr()).unwrap(),
C::Base::from_repr(b.y.to_repr()).unwrap(),
)
.unwrap();
let c_pasta = (a_pasta + b_pasta).to_affine();
let d_pasta = (a_pasta + a_pasta).to_affine();
let e_pasta = a_pasta
.mul(pasta_curves::Fq::from_repr(s.to_repr()).unwrap())
let c_curve = (a_curve + b_curve).to_affine();
let d_curve = (a_curve + a_curve).to_affine();
let e_curve = a_curve
.mul(C::Scalar::from_repr(s.to_repr()).unwrap())
.to_affine();

// transform c, d, and e into pasta_curve types
let c_pasta_2 = EpAffine::from_xy(
pasta_curves::Fp::from_repr(c.x.to_repr()).unwrap(),
pasta_curves::Fp::from_repr(c.y.to_repr()).unwrap(),
// transform c, d, and e into curve types
let c_curve_2 = C::from_xy(
C::Base::from_repr(c.x.to_repr()).unwrap(),
C::Base::from_repr(c.y.to_repr()).unwrap(),
)
.unwrap();
let d_pasta_2 = EpAffine::from_xy(
pasta_curves::Fp::from_repr(d.x.to_repr()).unwrap(),
pasta_curves::Fp::from_repr(d.y.to_repr()).unwrap(),
let d_curve_2 = C::from_xy(
C::Base::from_repr(d.x.to_repr()).unwrap(),
C::Base::from_repr(d.y.to_repr()).unwrap(),
)
.unwrap();
let e_pasta_2 = EpAffine::from_xy(
pasta_curves::Fp::from_repr(e.x.to_repr()).unwrap(),
pasta_curves::Fp::from_repr(e.y.to_repr()).unwrap(),
let e_curve_2 = C::from_xy(
C::Base::from_repr(e.x.to_repr()).unwrap(),
C::Base::from_repr(e.y.to_repr()).unwrap(),
)
.unwrap();

// check that we have the same outputs
assert_eq!(c_pasta, c_pasta_2);
assert_eq!(d_pasta, d_pasta_2);
assert_eq!(e_pasta, e_pasta_2);
assert_eq!(c_curve, c_curve_2);
assert_eq!(d_curve, d_curve_2);
assert_eq!(e_curve, e_curve_2);
}

fn synthesize_smul<G, CS>(mut cs: CS) -> (AllocatedPoint<G>, AllocatedPoint<G>, G::Scalar)
Expand All @@ -969,6 +975,17 @@ mod tests {

#[test]
fn test_ecc_circuit_ops() {
test_ecc_circuit_ops_with::<pallas::Base, pallas::Scalar, pallas::Point, vesta::Point>();
test_ecc_circuit_ops_with::<vesta::Base, vesta::Scalar, vesta::Point, pallas::Point>();
}

fn test_ecc_circuit_ops_with<B, S, G1, G2>()
where
B: PrimeField,
S: PrimeField,
G1: Group<Base = B, Scalar = S>,
G2: Group<Base = S, Scalar = B>,
{
// First create the shape
let mut cs: ShapeCS<G2> = ShapeCS::new();
let _ = synthesize_smul::<G1, _>(cs.namespace(|| "synthesize"));
Expand Down Expand Up @@ -1010,6 +1027,17 @@ mod tests {

#[test]
fn test_ecc_circuit_add_equal() {
test_ecc_circuit_add_equal_with::<pallas::Base, pallas::Scalar, pallas::Point, vesta::Point>();
test_ecc_circuit_add_equal_with::<vesta::Base, vesta::Scalar, vesta::Point, pallas::Point>();
}

fn test_ecc_circuit_add_equal_with<B, S, G1, G2>()
where
B: PrimeField,
S: PrimeField,
G1: Group<Base = B, Scalar = S>,
G2: Group<Base = S, Scalar = B>,
{
// First create the shape
let mut cs: ShapeCS<G2> = ShapeCS::new();
let _ = synthesize_add_equal::<G1, _>(cs.namespace(|| "synthesize add equal"));
Expand Down Expand Up @@ -1055,6 +1083,18 @@ mod tests {

#[test]
fn test_ecc_circuit_add_negation() {
test_ecc_circuit_add_negation_with::<pallas::Base, pallas::Scalar, pallas::Point, vesta::Point>(
);
test_ecc_circuit_add_negation_with::<vesta::Base, vesta::Scalar, vesta::Point, pallas::Point>();
}

fn test_ecc_circuit_add_negation_with<B, S, G1, G2>()
where
B: PrimeField,
S: PrimeField,
G1: Group<Base = B, Scalar = S>,
G2: Group<Base = S, Scalar = B>,
{
// First create the shape
let mut cs: ShapeCS<G2> = ShapeCS::new();
let _ = synthesize_add_negation::<G1, _>(cs.namespace(|| "synthesize add equal"));
Expand Down
4 changes: 2 additions & 2 deletions src/provider/pasta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ macro_rules! impl_traits {
}

fn get_curve_params() -> (Self::Base, Self::Base, BigInt) {
let A = Self::Base::zero();
let B = Self::Base::from(5);
let A = $name::Point::a();
let B = $name::Point::b();
let order = BigInt::from_str_radix($order_str, 16).unwrap();

(A, B, order)
Expand Down

0 comments on commit f16fa1e

Please sign in to comment.