diff --git a/Cargo.toml b/Cargo.toml index cea29a8a..6d30863d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2curves" -version = "0.2.1" +version = "0.3.1" authors = [ "Sean Bowe ", "Jack Grigg ", diff --git a/README.md b/README.md index 8ae82763..bf431dac 100644 --- a/README.md +++ b/README.md @@ -14,12 +14,12 @@ This implementation is mostly ported from [matterlabs/pairing](https://github.co ## Bench -None Assembly +No assembly ``` $ cargo test --profile bench test_field -- --nocapture ``` -Assembly +Assembly (returns rust nightly) ``` $ cargo test --profile bench test_field --features asm -- --nocapture ``` diff --git a/src/arithmetic.rs b/src/arithmetic.rs index 7fb2dbc6..388a422b 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -8,18 +8,17 @@ use subtle::{Choice, ConditionallySelectable, CtOption}; pub trait CurveAffineExt: pasta_curves::arithmetic::CurveAffine { fn batch_add( - _points: &mut [Self], - _output_indices: &[u32], - _num_points: usize, - _offset: usize, - _bases: &[Self], - _base_positions: &[u32], - ) { - unimplemented!() - } - - /// Unlike the `Coordinates` trait, this just returns the raw affine coordinantes without checking `is_on_curve` + points: &mut [Self], + output_indices: &[u32], + num_points: usize, + offset: usize, + bases: &[Self], + base_positions: &[u32], + ); + + /// Unlike the `Coordinates` trait, this just returns the raw affine coordinates without checking `is_on_curve` fn into_coordinates(self) -> (Self::Base, Self::Base) { + // fallback implementation let coordinates = self.coordinates().unwrap(); (*coordinates.x(), *coordinates.y()) } diff --git a/src/bn256/assembly.rs b/src/bn256/assembly.rs index a97a4db8..3e623426 100644 --- a/src/bn256/assembly.rs +++ b/src/bn256/assembly.rs @@ -12,311 +12,379 @@ macro_rules! assembly_field { $r2:ident, $r3:ident ) => { - use std::arch::asm; + use std::arch::asm; - impl $field { - /// Returns zero, the additive identity. - #[inline] - pub const fn zero() -> $field { - $field([0, 0, 0, 0]) - } + impl $field { + /// Returns zero, the additive identity. + #[inline] + pub const fn zero() -> $field { + $field([0, 0, 0, 0]) + } - /// Returns one, the multiplicative identity. - #[inline] - pub const fn one() -> $field { - $r - } + /// Returns one, the multiplicative identity. + #[inline] + pub const fn one() -> $field { + $r + } - fn from_u512(limbs: [u64; 8]) -> $field { - // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits - // with the higher bits multiplied by 2^256. Thus, we perform two reductions - // - // 1. the lower bits are multiplied by R^2, as normal - // 2. the upper bits are multiplied by R^2 * 2^256 = R^3 - // - // and computing their sum in the field. It remains to see that arbitrary 256-bit - // numbers can be placed into Montgomery form safely using the reduction. The - // reduction works so long as the product is less than R=2^256 multiplied by - // the modulus. This holds because for any `c` smaller than the modulus, we have - // that (2^256 - 1)*c is an acceptable product for the reduction. Therefore, the - // reduction always works so long as `c` is in the field; in this case it is either the - // constant `R2` or `R3`. - let d0 = $field([limbs[0], limbs[1], limbs[2], limbs[3]]); - let d1 = $field([limbs[4], limbs[5], limbs[6], limbs[7]]); - // Convert to Montgomery form - d0 * $r2 + d1 * $r3 - } + fn from_u512(limbs: [u64; 8]) -> $field { + // We reduce an arbitrary 512-bit number by decomposing it into two 256-bit digits + // with the higher bits multiplied by 2^256. Thus, we perform two reductions + // + // 1. the lower bits are multiplied by R^2, as normal + // 2. the upper bits are multiplied by R^2 * 2^256 = R^3 + // + // and computing their sum in the field. It remains to see that arbitrary 256-bit + // numbers can be placed into Montgomery form safely using the reduction. The + // reduction works so long as the product is less than R=2^256 multiplied by + // the modulus. This holds because for any `c` smaller than the modulus, we have + // that (2^256 - 1)*c is an acceptable product for the reduction. Therefore, the + // reduction always works so long as `c` is in the field; in this case it is either the + // constant `R2` or `R3`. + let d0 = $field([limbs[0], limbs[1], limbs[2], limbs[3]]); + let d1 = $field([limbs[4], limbs[5], limbs[6], limbs[7]]); + // Convert to Montgomery form + d0 * $r2 + d1 * $r3 + } - /// Converts from an integer represented in little endian - /// into its (congruent) `$field` representation. - pub const fn from_raw(val: [u64; 4]) -> Self { - // Multiplication - let (r0, carry) = mac(0, val[0], $r2.0[0], 0); - let (r1, carry) = mac(0, val[0], $r2.0[1], carry); - let (r2, carry) = mac(0, val[0], $r2.0[2], carry); - let (r3, r4) = mac(0, val[0], $r2.0[3], carry); - - let (r1, carry) = mac(r1, val[1], $r2.0[0], 0); - let (r2, carry) = mac(r2, val[1], $r2.0[1], carry); - let (r3, carry) = mac(r3, val[1], $r2.0[2], carry); - let (r4, r5) = mac(r4, val[1], $r2.0[3], carry); - - let (r2, carry) = mac(r2, val[2], $r2.0[0], 0); - let (r3, carry) = mac(r3, val[2], $r2.0[1], carry); - let (r4, carry) = mac(r4, val[2], $r2.0[2], carry); - let (r5, r6) = mac(r5, val[2], $r2.0[3], carry); - - let (r3, carry) = mac(r3, val[3], $r2.0[0], 0); - let (r4, carry) = mac(r4, val[3], $r2.0[1], carry); - let (r5, carry) = mac(r5, val[3], $r2.0[2], carry); - let (r6, r7) = mac(r6, val[3], $r2.0[3], carry); - - // Montgomery reduction (first part) - let k = r0.wrapping_mul($inv); - let (_, carry) = mac(r0, k, $modulus.0[0], 0); - let (r1, carry) = mac(r1, k, $modulus.0[1], carry); - let (r2, carry) = mac(r2, k, $modulus.0[2], carry); - let (r3, carry) = mac(r3, k, $modulus.0[3], carry); - let (r4, carry2) = adc(r4, 0, carry); - - let k = r1.wrapping_mul($inv); - let (_, carry) = mac(r1, k, $modulus.0[0], 0); - let (r2, carry) = mac(r2, k, $modulus.0[1], carry); - let (r3, carry) = mac(r3, k, $modulus.0[2], carry); - let (r4, carry) = mac(r4, k, $modulus.0[3], carry); - let (r5, carry2) = adc(r5, carry2, carry); - - let k = r2.wrapping_mul($inv); - let (_, carry) = mac(r2, k, $modulus.0[0], 0); - let (r3, carry) = mac(r3, k, $modulus.0[1], carry); - let (r4, carry) = mac(r4, k, $modulus.0[2], carry); - let (r5, carry) = mac(r5, k, $modulus.0[3], carry); - let (r6, carry2) = adc(r6, carry2, carry); - - let k = r3.wrapping_mul($inv); - let (_, carry) = mac(r3, k, $modulus.0[0], 0); - let (r4, carry) = mac(r4, k, $modulus.0[1], carry); - let (r5, carry) = mac(r5, k, $modulus.0[2], carry); - let (r6, carry) = mac(r6, k, $modulus.0[3], carry); - let (r7, _) = adc(r7, carry2, carry); - - // Montgomery reduction (sub part) - let (d0, borrow) = sbb(r4, $modulus.0[0], 0); - let (d1, borrow) = sbb(r5, $modulus.0[1], borrow); - let (d2, borrow) = sbb(r6, $modulus.0[2], borrow); - let (d3, borrow) = sbb(r7, $modulus.0[3], borrow); - - let (d0, carry) = adc(d0, $modulus.0[0] & borrow, 0); - let (d1, carry) = adc(d1, $modulus.0[1] & borrow, carry); - let (d2, carry) = adc(d2, $modulus.0[2] & borrow, carry); - let (d3, _) = adc(d3, $modulus.0[3] & borrow, carry); - - $field([d0, d1, d2, d3]) - } + /// Converts from an integer represented in little endian + /// into its (congruent) `$field` representation. + pub const fn from_raw(val: [u64; 4]) -> Self { + // Multiplication + let (r0, carry) = mac(0, val[0], $r2.0[0], 0); + let (r1, carry) = mac(0, val[0], $r2.0[1], carry); + let (r2, carry) = mac(0, val[0], $r2.0[2], carry); + let (r3, r4) = mac(0, val[0], $r2.0[3], carry); + + let (r1, carry) = mac(r1, val[1], $r2.0[0], 0); + let (r2, carry) = mac(r2, val[1], $r2.0[1], carry); + let (r3, carry) = mac(r3, val[1], $r2.0[2], carry); + let (r4, r5) = mac(r4, val[1], $r2.0[3], carry); + + let (r2, carry) = mac(r2, val[2], $r2.0[0], 0); + let (r3, carry) = mac(r3, val[2], $r2.0[1], carry); + let (r4, carry) = mac(r4, val[2], $r2.0[2], carry); + let (r5, r6) = mac(r5, val[2], $r2.0[3], carry); + + let (r3, carry) = mac(r3, val[3], $r2.0[0], 0); + let (r4, carry) = mac(r4, val[3], $r2.0[1], carry); + let (r5, carry) = mac(r5, val[3], $r2.0[2], carry); + let (r6, r7) = mac(r6, val[3], $r2.0[3], carry); + + // Montgomery reduction (first part) + let k = r0.wrapping_mul($inv); + let (_, carry) = mac(r0, k, $modulus.0[0], 0); + let (r1, carry) = mac(r1, k, $modulus.0[1], carry); + let (r2, carry) = mac(r2, k, $modulus.0[2], carry); + let (r3, carry) = mac(r3, k, $modulus.0[3], carry); + let (r4, carry2) = adc(r4, 0, carry); + + let k = r1.wrapping_mul($inv); + let (_, carry) = mac(r1, k, $modulus.0[0], 0); + let (r2, carry) = mac(r2, k, $modulus.0[1], carry); + let (r3, carry) = mac(r3, k, $modulus.0[2], carry); + let (r4, carry) = mac(r4, k, $modulus.0[3], carry); + let (r5, carry2) = adc(r5, carry2, carry); + + let k = r2.wrapping_mul($inv); + let (_, carry) = mac(r2, k, $modulus.0[0], 0); + let (r3, carry) = mac(r3, k, $modulus.0[1], carry); + let (r4, carry) = mac(r4, k, $modulus.0[2], carry); + let (r5, carry) = mac(r5, k, $modulus.0[3], carry); + let (r6, carry2) = adc(r6, carry2, carry); + + let k = r3.wrapping_mul($inv); + let (_, carry) = mac(r3, k, $modulus.0[0], 0); + let (r4, carry) = mac(r4, k, $modulus.0[1], carry); + let (r5, carry) = mac(r5, k, $modulus.0[2], carry); + let (r6, carry) = mac(r6, k, $modulus.0[3], carry); + let (r7, _) = adc(r7, carry2, carry); + + // Montgomery reduction (sub part) + let (d0, borrow) = sbb(r4, $modulus.0[0], 0); + let (d1, borrow) = sbb(r5, $modulus.0[1], borrow); + let (d2, borrow) = sbb(r6, $modulus.0[2], borrow); + let (d3, borrow) = sbb(r7, $modulus.0[3], borrow); + + let (d0, carry) = adc(d0, $modulus.0[0] & borrow, 0); + let (d1, carry) = adc(d1, $modulus.0[1] & borrow, carry); + let (d2, carry) = adc(d2, $modulus.0[2] & borrow, carry); + let (d3, _) = adc(d3, $modulus.0[3] & borrow, carry); + + $field([d0, d1, d2, d3]) + } - /// Attempts to convert a little-endian byte representation of - /// a scalar into a `Fr`, failing if the input is not canonical. - pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<$field> { - ::from_repr(*bytes) - } + /// Attempts to convert a little-endian byte representation of + /// a scalar into a `Fr`, failing if the input is not canonical. + pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<$field> { + ::from_repr(*bytes) + } - /// Converts an element of `Fr` into a byte representation in - /// little-endian byte order. - pub fn to_bytes(&self) -> [u8; 32] { - ::to_repr(self) - } + /// Converts an element of `Fr` into a byte representation in + /// little-endian byte order. + pub fn to_bytes(&self) -> [u8; 32] { + ::to_repr(self) } + } - impl Group for $field { - type Scalar = Self; + impl Group for $field { + type Scalar = Self; - fn group_zero() -> Self { - Self::zero() - } - fn group_add(&mut self, rhs: &Self) { - *self += *rhs; - } - fn group_sub(&mut self, rhs: &Self) { - *self -= *rhs; - } - fn group_scale(&mut self, by: &Self::Scalar) { - *self *= *by; - } + fn group_zero() -> Self { + Self::zero() } - - impl fmt::Debug for $field { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let tmp = self.to_repr(); - write!(f, "0x")?; - for &b in tmp.iter().rev() { - write!(f, "{:02x}", b)?; - } - Ok(()) - } + fn group_add(&mut self, rhs: &Self) { + *self += *rhs; + } + fn group_sub(&mut self, rhs: &Self) { + *self -= *rhs; + } + fn group_scale(&mut self, by: &Self::Scalar) { + *self *= *by; } + } - impl Default for $field { - #[inline] - fn default() -> Self { - Self::zero() + impl fmt::Debug for $field { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let tmp = self.to_repr(); + write!(f, "0x")?; + for &b in tmp.iter().rev() { + write!(f, "{:02x}", b)?; } + Ok(()) } + } - impl From for $field { - fn from(bit: bool) -> $field { - if bit { - $field::one() - } else { - $field::zero() - } - } + impl Default for $field { + #[inline] + fn default() -> Self { + Self::zero() } + } - impl From for $field { - fn from(val: u64) -> $field { - $field([val, 0, 0, 0]) * $r2 + impl From for $field { + fn from(bit: bool) -> $field { + if bit { + $field::one() + } else { + $field::zero() } } + } - impl ConstantTimeEq for $field { - fn ct_eq(&self, other: &Self) -> Choice { - self.0[0].ct_eq(&other.0[0]) - & self.0[1].ct_eq(&other.0[1]) - & self.0[2].ct_eq(&other.0[2]) - & self.0[3].ct_eq(&other.0[3]) - } + impl From for $field { + fn from(val: u64) -> $field { + $field([val, 0, 0, 0]) * $r2 } + } - impl PartialEq for $field { - #[inline] - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } + impl ConstantTimeEq for $field { + fn ct_eq(&self, other: &Self) -> Choice { + self.0[0].ct_eq(&other.0[0]) + & self.0[1].ct_eq(&other.0[1]) + & self.0[2].ct_eq(&other.0[2]) + & self.0[3].ct_eq(&other.0[3]) } + } - impl core::cmp::Ord for $field { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - let left = self.to_repr(); - let right = other.to_repr(); - left.iter() - .zip(right.iter()) - .rev() - .find_map(|(left_byte, right_byte)| match left_byte.cmp(right_byte) { - core::cmp::Ordering::Equal => None, - res => Some(res), - }) - .unwrap_or(core::cmp::Ordering::Equal) - } + impl core::cmp::Ord for $field { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + let left = self.to_repr(); + let right = other.to_repr(); + left.iter() + .zip(right.iter()) + .rev() + .find_map(|(left_byte, right_byte)| match left_byte.cmp(right_byte) { + core::cmp::Ordering::Equal => None, + res => Some(res), + }) + .unwrap_or(core::cmp::Ordering::Equal) } + } - impl core::cmp::PartialOrd for $field { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } + impl core::cmp::PartialOrd for $field { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } + } - impl ConditionallySelectable for $field { - fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { - $field([ - u64::conditional_select(&a.0[0], &b.0[0], choice), - u64::conditional_select(&a.0[1], &b.0[1], choice), - u64::conditional_select(&a.0[2], &b.0[2], choice), - u64::conditional_select(&a.0[3], &b.0[3], choice), - ]) - } + impl ConditionallySelectable for $field { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + $field([ + u64::conditional_select(&a.0[0], &b.0[0], choice), + u64::conditional_select(&a.0[1], &b.0[1], choice), + u64::conditional_select(&a.0[2], &b.0[2], choice), + u64::conditional_select(&a.0[3], &b.0[3], choice), + ]) } + } - impl<'a> Neg for &'a $field { - type Output = $field; + impl<'a> Neg for &'a $field { + type Output = $field; - #[inline] - fn neg(self) -> $field { - self.neg() - } + #[inline] + fn neg(self) -> $field { + self.neg() } + } - impl Neg for $field { - type Output = $field; + impl Neg for $field { + type Output = $field; - #[inline] - fn neg(self) -> $field { - -&self - } + #[inline] + fn neg(self) -> $field { + -&self } + } - impl<'a, 'b> Sub<&'b $field> for &'a $field { - type Output = $field; + impl<'a, 'b> Sub<&'b $field> for &'a $field { + type Output = $field; - #[inline] - fn sub(self, rhs: &'b $field) -> $field { - self.sub(rhs) - } + #[inline] + fn sub(self, rhs: &'b $field) -> $field { + self.sub(rhs) } + } - impl<'a, 'b> Add<&'b $field> for &'a $field { - type Output = $field; + impl<'a, 'b> Add<&'b $field> for &'a $field { + type Output = $field; - #[inline] - fn add(self, rhs: &'b $field) -> $field { - self.add(rhs) - } + #[inline] + fn add(self, rhs: &'b $field) -> $field { + self.add(rhs) } + } - impl<'a, 'b> Mul<&'b $field> for &'a $field { - type Output = $field; + impl<'a, 'b> Mul<&'b $field> for &'a $field { + type Output = $field; - #[inline] - fn mul(self, rhs: &'b $field) -> $field { - self.mul(rhs) - } + #[inline] + fn mul(self, rhs: &'b $field) -> $field { + self.mul(rhs) } + } - impl From<$field> for [u8; 32] { - fn from(value: $field) -> [u8; 32] { - value.to_repr() - } + impl From<$field> for [u8; 32] { + fn from(value: $field) -> [u8; 32] { + value.to_repr() } + } - impl<'a> From<&'a $field> for [u8; 32] { - fn from(value: &'a $field) -> [u8; 32] { - value.to_repr() - } + impl<'a> From<&'a $field> for [u8; 32] { + fn from(value: &'a $field) -> [u8; 32] { + value.to_repr() } + } - impl FieldExt for $field { - const MODULUS: &'static str = $modulus_str; - const TWO_INV: Self = $two_inv; - const ROOT_OF_UNITY_INV: Self = $root_of_unity_inv; - const DELTA: Self = $delta; - const ZETA: Self = $zeta; + impl FieldExt for $field { + const MODULUS: &'static str = $modulus_str; + const TWO_INV: Self = $two_inv; + const ROOT_OF_UNITY_INV: Self = $root_of_unity_inv; + const DELTA: Self = $delta; + const ZETA: Self = $zeta; - fn from_u128(v: u128) -> Self { - $field::from_raw([v as u64, (v >> 64) as u64, 0, 0]) - } + fn from_u128(v: u128) -> Self { + $field::from_raw([v as u64, (v >> 64) as u64, 0, 0]) + } - /// Converts a 512-bit little endian integer into - /// a `$field` by reducing by the modulus. - fn from_bytes_wide(bytes: &[u8; 64]) -> $field { - $field::from_u512([ - u64::from_le_bytes(bytes[0..8].try_into().unwrap()), - u64::from_le_bytes(bytes[8..16].try_into().unwrap()), - u64::from_le_bytes(bytes[16..24].try_into().unwrap()), - u64::from_le_bytes(bytes[24..32].try_into().unwrap()), - u64::from_le_bytes(bytes[32..40].try_into().unwrap()), - u64::from_le_bytes(bytes[40..48].try_into().unwrap()), - u64::from_le_bytes(bytes[48..56].try_into().unwrap()), - u64::from_le_bytes(bytes[56..64].try_into().unwrap()), - ]) - } + /// Converts a 512-bit little endian integer into + /// a `$field` by reducing by the modulus. + fn from_bytes_wide(bytes: &[u8; 64]) -> $field { + $field::from_u512([ + u64::from_le_bytes(bytes[0..8].try_into().unwrap()), + u64::from_le_bytes(bytes[8..16].try_into().unwrap()), + u64::from_le_bytes(bytes[16..24].try_into().unwrap()), + u64::from_le_bytes(bytes[24..32].try_into().unwrap()), + u64::from_le_bytes(bytes[32..40].try_into().unwrap()), + u64::from_le_bytes(bytes[40..48].try_into().unwrap()), + u64::from_le_bytes(bytes[48..56].try_into().unwrap()), + u64::from_le_bytes(bytes[56..64].try_into().unwrap()), + ]) + } + + fn get_lower_128(&self) -> u128 { + let tmp = $field::montgomery_reduce(&[ + self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0, + ]); - fn get_lower_128(&self) -> u128 { - let tmp = $field::montgomery_reduce(&[ - self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0, - ]); + u128::from(tmp.0[0]) | (u128::from(tmp.0[1]) << 64) + } + } - u128::from(tmp.0[0]) | (u128::from(tmp.0[1]) << 64) + impl $crate::serde::SerdeObject for $field { + fn from_raw_bytes_unchecked(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), 32); + let inner = + [0, 8, 16, 24].map(|i| u64::from_le_bytes(bytes[i..i + 8].try_into().unwrap())); + Self(inner) + } + fn from_raw_bytes(bytes: &[u8]) -> Option { + if bytes.len() != 32 { + return None; + } + let elt = Self::from_raw_bytes_unchecked(bytes); + is_less_than(&elt.0, &$modulus.0).then(|| elt) + } + fn to_raw_bytes(&self) -> Vec { + let mut res = Vec::with_capacity(32); + for limb in self.0.iter() { + res.extend_from_slice(&limb.to_le_bytes()); } + res } + fn read_raw_unchecked(reader: &mut R) -> Self { + let inner = [(); 4].map(|_| { + let mut buf = [0; 8]; + reader.read_exact(&mut buf).unwrap(); + u64::from_le_bytes(buf) + }); + Self(inner) + } + fn read_raw(reader: &mut R) -> std::io::Result { + let mut inner = [0u64; 4]; + for limb in inner.iter_mut() { + let mut buf = [0; 8]; + reader.read_exact(&mut buf)?; + *limb = u64::from_le_bytes(buf); + } + let elt = Self(inner); + is_less_than(&elt.0, &$modulus.0) + .then(|| elt) + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "input number is not less than field modulus", + ) + }) + } + fn write_raw(&self, writer: &mut W) -> std::io::Result<()> { + for limb in self.0.iter() { + writer.write_all(&limb.to_le_bytes())?; + } + Ok(()) + } + } + + /// Lexicographic comparison of Montgomery forms. + #[inline(always)] + fn is_less_than(x: &[u64; 4], y: &[u64; 4]) -> bool { + match x[3].cmp(&y[3]) { + core::cmp::Ordering::Less => return true, + core::cmp::Ordering::Greater => return false, + _ => {} + } + match x[2].cmp(&y[2]) { + core::cmp::Ordering::Less => return true, + core::cmp::Ordering::Greater => return false, + _ => {} + } + match x[1].cmp(&y[1]) { + core::cmp::Ordering::Less => return true, + core::cmp::Ordering::Greater => return false, + _ => {} + } + x[0].lt(&y[0]) + } impl $field { /// Doubles this field element. diff --git a/src/bn256/curve.rs b/src/bn256/curve.rs index 890db244..0744481e 100644 --- a/src/bn256/curve.rs +++ b/src/bn256/curve.rs @@ -26,6 +26,7 @@ new_curve_impl!( G1, G1Affine, G1Compressed, + Fq::size(), Fq, Fr, (G1_GENERATOR_X,G1_GENERATOR_Y), @@ -38,6 +39,7 @@ new_curve_impl!( G2, G2Affine, G2Compressed, + Fq2::size(), Fq2, Fr, (G2_GENERATOR_X, G2_GENERATOR_Y), @@ -292,6 +294,12 @@ mod tests { assert_eq!(res_affine, exp_affine); } + + #[test] + fn test_serialization() { + crate::tests::curve::random_serialization_test::(); + crate::tests::curve::random_serialization_test::(); + } } impl group::UncompressedEncoding for G1Affine { diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index e130d9a0..7cf7b5f1 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -354,4 +354,9 @@ mod test { fn test_field() { crate::tests::field::random_field_tests::("fq".to_string()); } + + #[test] + fn test_serialization() { + crate::tests::field::random_serialization_test::("fq".to_string()); + } } diff --git a/src/bn256/fq2.rs b/src/bn256/fq2.rs index 9b45c3cd..2972c4fc 100644 --- a/src/bn256/fq2.rs +++ b/src/bn256/fq2.rs @@ -572,6 +572,41 @@ impl ff::PrimeField for Fq2 { } } +impl crate::serde::SerdeObject for Fq2 { + fn from_raw_bytes_unchecked(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), 64); + let [c0, c1] = [0, 32].map(|i| Fq::from_raw_bytes_unchecked(&bytes[i..i + 32])); + Self { c0, c1 } + } + fn from_raw_bytes(bytes: &[u8]) -> Option { + if bytes.len() != 64 { + return None; + } + let [c0, c1] = [0, 32].map(|i| Fq::from_raw_bytes(&bytes[i..i + 32])); + c0.zip(c1).map(|(c0, c1)| Self { c0, c1 }) + } + fn to_raw_bytes(&self) -> Vec { + let mut res = Vec::with_capacity(64); + for limb in self.c0.0.iter().chain(self.c1.0.iter()) { + res.extend_from_slice(&limb.to_le_bytes()); + } + res + } + fn read_raw_unchecked(reader: &mut R) -> Self { + let [c0, c1] = [(); 2].map(|_| Fq::read_raw_unchecked(reader)); + Self { c0, c1 } + } + fn read_raw(reader: &mut R) -> std::io::Result { + let c0 = Fq::read_raw(reader)?; + let c1 = Fq::read_raw(reader)?; + Ok(Self { c0, c1 }) + } + fn write_raw(&self, writer: &mut W) -> std::io::Result<()> { + self.c0.write_raw(writer)?; + self.c1.write_raw(writer) + } +} + pub const FROBENIUS_COEFF_FQ2_C1: [Fq; 2] = [ // Fq(-1)**(((q^0) - 1) / 2) // it's 1 in Montgommery form @@ -799,3 +834,8 @@ fn test_frobenius() { fn test_field() { crate::tests::field::random_field_tests::("fq2".to_string()); } + +#[test] +fn test_serialization() { + crate::tests::field::random_serialization_test::("fq2".to_string()); +} diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 520f36d4..790517ee 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -363,4 +363,9 @@ mod test { ]) ); } + + #[test] + fn test_serialization() { + crate::tests::field::random_serialization_test::("fr".to_string()); + } } diff --git a/src/derive/curve.rs b/src/derive/curve.rs index fb38ae81..a038898d 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -146,6 +146,7 @@ macro_rules! new_curve_impl { $name:ident, $name_affine:ident, $name_compressed:ident, + $compressed_size:expr, $base:ident, $scalar:ident, $generator:expr, @@ -153,7 +154,7 @@ macro_rules! new_curve_impl { $curve_id:literal, ) => { - #[derive(Copy, Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] + #[derive(Copy, Clone, Debug, Serialize, Deserialize)] $($privacy)* struct $name { pub x: $base, pub y: $base, @@ -167,8 +168,7 @@ macro_rules! new_curve_impl { } #[derive(Copy, Clone, Hash)] - $($privacy)* struct $name_compressed([u8; $base::size()]); - + $($privacy)* struct $name_compressed([u8; $compressed_size]); impl $name { pub fn generator() -> Self { @@ -232,7 +232,7 @@ macro_rules! new_curve_impl { impl Default for $name_compressed { fn default() -> Self { - $name_compressed([0; $base::size()]) + $name_compressed([0; $compressed_size]) } } @@ -302,6 +302,12 @@ macro_rules! new_curve_impl { } } + impl PartialEq for $name { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } + } + impl cmp::Eq for $name {} impl CurveExt for $name { @@ -473,6 +479,46 @@ macro_rules! new_curve_impl { } } + impl $crate::serde::SerdeObject for $name { + fn from_raw_bytes_unchecked(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), 3 * $base::size()); + let [x, y, z] = [0, 1, 2] + .map(|i| $base::from_raw_bytes_unchecked(&bytes[i * $base::size()..(i + 1) * $base::size()])); + Self { x, y, z } + } + fn from_raw_bytes(bytes: &[u8]) -> Option { + if bytes.len() != 3 * $base::size() { + return None; + } + let [x, y, z] = + [0, 1, 2].map(|i| $base::from_raw_bytes(&bytes[i * $base::size()..(i + 1) * $base::size()])); + x.zip(y).zip(z).and_then(|((x, y), z)| { + let res = Self { x, y, z }; + // Check that the point is on the curve. + bool::from(res.is_on_curve()).then(|| res) + }) + } + fn to_raw_bytes(&self) -> Vec { + let mut res = Vec::with_capacity(3 * $base::size()); + Self::write_raw(self, &mut res).unwrap(); + res + } + fn read_raw_unchecked(reader: &mut R) -> Self { + let [x, y, z] = [(); 3].map(|_| $base::read_raw_unchecked(reader)); + Self { x, y, z } + } + fn read_raw(reader: &mut R) -> std::io::Result { + let x = $base::read_raw(reader)?; + let y = $base::read_raw(reader)?; + let z = $base::read_raw(reader)?; + Ok(Self { x, y, z }) + } + fn write_raw(&self, writer: &mut W) -> std::io::Result<()> { + self.x.write_raw(writer)?; + self.y.write_raw(writer)?; + self.z.write_raw(writer) + } + } impl group::prime::PrimeGroup for $name {} @@ -557,10 +603,12 @@ macro_rules! new_curve_impl { fn from_bytes(bytes: &Self::Repr) -> CtOption { let bytes = &bytes.0; let mut tmp = *bytes; - let ysign = Choice::from(tmp[$base::size() - 1] >> 7); - tmp[$base::size() - 1] &= 0b0111_1111; + let ysign = Choice::from(tmp[$compressed_size - 1] >> 7); + tmp[$compressed_size - 1] &= 0b0111_1111; + let mut xbytes = [0u8; $base::size()]; + xbytes.copy_from_slice(&tmp[ ..$base::size()]); - $base::from_bytes(&tmp).and_then(|x| { + $base::from_bytes(&xbytes).and_then(|x| { CtOption::new(Self::identity(), x.is_zero() & (!ysign)).or_else(|| { let x3 = x.square() * x; (x3 + $name::curve_constant_b()).sqrt().and_then(|y| { @@ -590,13 +638,52 @@ macro_rules! new_curve_impl { } else { let (x, y) = (self.x, self.y); let sign = (y.to_bytes()[0] & 1) << 7; - let mut xbytes = x.to_bytes(); - xbytes[$base::size() - 1] |= sign; + let mut xbytes = [0u8; $compressed_size]; + xbytes[..$base::size()].copy_from_slice(&x.to_bytes()); + xbytes[$compressed_size - 1] |= sign; $name_compressed(xbytes) } } } + impl crate::serde::SerdeObject for $name_affine { + fn from_raw_bytes_unchecked(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), 2 * $base::size()); + let [x, y] = + [0, $base::size()].map(|i| $base::from_raw_bytes_unchecked(&bytes[i..i + $base::size()])); + Self { x, y } + } + fn from_raw_bytes(bytes: &[u8]) -> Option { + if bytes.len() != 2 * $base::size() { + return None; + } + let [x, y] = [0, $base::size()].map(|i| $base::from_raw_bytes(&bytes[i..i + $base::size()])); + x.zip(y).and_then(|(x, y)| { + let res = Self { x, y }; + // Check that the point is on the curve. + bool::from(res.is_on_curve()).then(|| res) + }) + } + fn to_raw_bytes(&self) -> Vec { + let mut res = Vec::with_capacity(2 * $base::size()); + Self::write_raw(self, &mut res).unwrap(); + res + } + fn read_raw_unchecked(reader: &mut R) -> Self { + let [x, y] = [(); 2].map(|_| $base::read_raw_unchecked(reader)); + Self { x, y } + } + fn read_raw(reader: &mut R) -> std::io::Result { + let x = $base::read_raw(reader)?; + let y = $base::read_raw(reader)?; + Ok(Self { x, y }) + } + fn write_raw(&self, writer: &mut W) -> std::io::Result<()> { + self.x.write_raw(writer)?; + self.y.write_raw(writer) + } + } + impl group::prime::PrimeCurveAffine for $name_affine { type Curve = $name; type Scalar = $scalar; diff --git a/src/derive/field.rs b/src/derive/field.rs index d1f138d9..7c42aa52 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -290,6 +290,60 @@ macro_rules! field_common { u128::from(tmp.0[0]) | (u128::from(tmp.0[1]) << 64) } } + + impl $crate::serde::SerdeObject for $field { + fn from_raw_bytes_unchecked(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), 32); + let inner = + [0, 8, 16, 24].map(|i| u64::from_le_bytes(bytes[i..i + 8].try_into().unwrap())); + Self(inner) + } + fn from_raw_bytes(bytes: &[u8]) -> Option { + if bytes.len() != 32 { + return None; + } + let elt = Self::from_raw_bytes_unchecked(bytes); + Self::is_less_than(&elt.0, &$modulus.0).then(|| elt) + } + fn to_raw_bytes(&self) -> Vec { + let mut res = Vec::with_capacity(32); + for limb in self.0.iter() { + res.extend_from_slice(&limb.to_le_bytes()); + } + res + } + fn read_raw_unchecked(reader: &mut R) -> Self { + let inner = [(); 4].map(|_| { + let mut buf = [0; 8]; + reader.read_exact(&mut buf).unwrap(); + u64::from_le_bytes(buf) + }); + Self(inner) + } + fn read_raw(reader: &mut R) -> std::io::Result { + let mut inner = [0u64; 4]; + for limb in inner.iter_mut() { + let mut buf = [0; 8]; + reader.read_exact(&mut buf)?; + *limb = u64::from_le_bytes(buf); + } + let elt = Self(inner); + Self::is_less_than(&elt.0, &$modulus.0) + .then(|| elt) + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + "input number is not less than field modulus", + ) + }) + } + fn write_raw(&self, writer: &mut W) -> std::io::Result<()> { + for limb in self.0.iter() { + writer.write_all(&limb.to_le_bytes())?; + } + Ok(()) + } + } }; } @@ -309,22 +363,22 @@ macro_rules! field_arithmetic { pub const fn const_mul(&self, rhs: &Self) -> $field { // Schoolbook multiplication - let (r0, carry) = self.0[0].widening_mul(rhs.0[0]); - let (r1, carry) = macx(carry, self.0[0], rhs.0[1]); - let (r2, carry) = macx(carry, self.0[0], rhs.0[2]); - let (r3, r4) = macx(carry, self.0[0], rhs.0[3]); + let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0); + let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry); + let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry); + let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry); - let (r1, carry) = macx(r1, self.0[1], rhs.0[0]); + let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0); let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry); let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry); let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry); - let (r2, carry) = macx(r2, self.0[2], rhs.0[0]); + let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0); let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry); let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry); let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry); - let (r3, carry) = macx(r3, self.0[3], rhs.0[0]); + let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0); let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry); let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry); let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry); @@ -454,7 +508,7 @@ macro_rules! field_arithmetic { (r2, r3) = mac(r2, k, $modulus.0[3], r3); // Result may be within MODULUS of the correct value - if !$field::is_less_than([r0, r1, r2, r3], $modulus.0) { + if !$field::is_less_than(&[r0, r1, r2, r3], &$modulus.0) { let mut borrow; (r0, borrow) = r0.overflowing_sub($modulus.0[0]); (r1, borrow) = sbb(r1, $modulus.0[1], borrow); @@ -566,7 +620,7 @@ macro_rules! field_specific { t3 = r0 + r1; // Result may be within MODULUS of the correct value - if !$field::is_less_than([t0, t1, t2, t3], $modulus.0) { + if !$field::is_less_than(&[t0, t1, t2, t3], &$modulus.0) { let mut borrow; (t0, borrow) = t0.overflowing_sub($modulus.0[0]); (t1, borrow) = sbb(t1, $modulus.0[1], borrow); diff --git a/src/lib.rs b/src/lib.rs index 175c1533..abdf1e2c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ pub mod bn256; pub mod pairing; pub mod pasta; pub mod secp256k1; +pub mod serde; #[macro_use] mod derive; diff --git a/src/secp256k1/curve.rs b/src/secp256k1/curve.rs index 4088da1b..8087cc08 100644 --- a/src/secp256k1/curve.rs +++ b/src/secp256k1/curve.rs @@ -60,6 +60,7 @@ new_curve_impl!( Secp256k1, Secp256k1Affine, Secp256k1Compressed, + 33, Fp, Fq, (SECP_GENERATOR_X,SECP_GENERATOR_Y), @@ -80,6 +81,11 @@ fn test_curve() { crate::tests::curve::curve_tests::(); } +#[test] +fn test_serialization() { + crate::tests::curve::random_serialization_test::(); +} + #[test] fn ecdsa_example() { use crate::group::Curve; diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index 25932079..e032c7ab 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -275,4 +275,9 @@ mod test { fn test_field() { crate::tests::field::random_field_tests::("secp256k1 base".to_string()); } + + #[test] + fn test_serialization() { + crate::tests::field::random_serialization_test::("secp256k1 base".to_string()); + } } diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 86947590..1df876a1 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -318,4 +318,9 @@ mod test { fn test_field() { crate::tests::field::random_field_tests::("secp256k1 scalar".to_string()); } + + #[test] + fn test_serialization() { + crate::tests::field::random_serialization_test::("secp256k1 scalar".to_string()); + } } diff --git a/src/serde.rs b/src/serde.rs new file mode 100644 index 00000000..5d0f24fc --- /dev/null +++ b/src/serde.rs @@ -0,0 +1,25 @@ +use std::io::{self, Read, Write}; + +/// Trait for converting raw bytes to/from the internal representation of a type. +/// For example, field elements are represented in Montgomery form and serialized/deserialized without Montgomery reduction. +pub trait SerdeObject: Sized { + /// The purpose of unchecked functions is to read the internal memory representation + /// of a type from bytes as quickly as possible. No sanitization checks are performed + /// to ensure the bytes represent a valid object. As such this function should only be + /// used internally as an extension of machine memory. It should not be used to deserialize + /// externally provided data. + fn from_raw_bytes_unchecked(bytes: &[u8]) -> Self; + fn from_raw_bytes(bytes: &[u8]) -> Option; + + fn to_raw_bytes(&self) -> Vec; + + /// The purpose of unchecked functions is to read the internal memory representation + /// of a type from disk as quickly as possible. No sanitization checks are performed + /// to ensure the bytes represent a valid object. This function should only be used + /// internally when some machine state cannot be kept in memory (e.g., between runs) + /// and needs to be reloaded as quickly as possible. + fn read_raw_unchecked(reader: &mut R) -> Self; + fn read_raw(reader: &mut R) -> io::Result; + + fn write_raw(&self, writer: &mut W) -> io::Result<()>; +} diff --git a/src/tests/curve.rs b/src/tests/curve.rs index 03999a1c..e6e6797f 100644 --- a/src/tests/curve.rs +++ b/src/tests/curve.rs @@ -1,3 +1,5 @@ +#![allow(clippy::eq_op)] +use crate::{group::GroupEncoding, serde::SerdeObject}; use ff::Field; use group::prime::PrimeCurveAffine; use pasta_curves::arithmetic::{CurveAffine, CurveExt}; @@ -11,6 +13,59 @@ pub fn curve_tests() { mixed_addition::(); multiplication::(); batch_normalize::(); + serdes::(); +} + +fn serdes() { + for _ in 0..100 { + let projective_point = G::random(OsRng); + let affine_point: G::AffineExt = projective_point.into(); + let projective_repr = projective_point.to_bytes(); + let affine_repr = affine_point.to_bytes(); + + println!( + "{:?} \n{:?}", + projective_repr.as_ref(), + affine_repr.as_ref() + ); + + let projective_point_rec = G::from_bytes(&projective_repr).unwrap(); + let projective_point_rec_unchecked = G::from_bytes(&projective_repr).unwrap(); + let affine_point_rec = G::AffineExt::from_bytes(&affine_repr).unwrap(); + let affine_point_rec_unchecked = G::AffineExt::from_bytes(&affine_repr).unwrap(); + + assert_eq!(projective_point, projective_point_rec); + assert_eq!(projective_point, projective_point_rec_unchecked); + assert_eq!(affine_point, affine_point_rec); + assert_eq!(affine_point, affine_point_rec_unchecked); + } +} + +pub fn random_serialization_test() +where + G: SerdeObject, + G::AffineExt: SerdeObject, +{ + for _ in 0..100 { + let projective_point = G::random(OsRng); + let affine_point: G::AffineExt = projective_point.into(); + + let projective_bytes = projective_point.to_raw_bytes(); + let projective_point_rec = G::from_raw_bytes(&projective_bytes).unwrap(); + assert_eq!(projective_point, projective_point_rec); + let mut buf = Vec::new(); + projective_point.write_raw(&mut buf).unwrap(); + let projective_point_rec = G::read_raw(&mut &buf[..]).unwrap(); + assert_eq!(projective_point, projective_point_rec); + + let affine_bytes = affine_point.to_raw_bytes(); + let affine_point_rec = G::AffineExt::from_raw_bytes(&affine_bytes).unwrap(); + assert_eq!(affine_point, affine_point_rec); + let mut buf = Vec::new(); + affine_point.write_raw(&mut buf).unwrap(); + let affine_point_rec = G::AffineExt::read_raw(&mut &buf[..]).unwrap(); + assert_eq!(affine_point, affine_point_rec); + } } fn is_on_curve() { diff --git a/src/tests/field.rs b/src/tests/field.rs index 108349f3..ab89a19c 100644 --- a/src/tests/field.rs +++ b/src/tests/field.rs @@ -3,6 +3,8 @@ use ff::Field; use rand::{RngCore, SeedableRng}; use rand_xorshift::XorShiftRng; +use crate::serde::SerdeObject; + pub fn random_field_tests(type_name: String) { let mut rng = XorShiftRng::from_seed([ 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, @@ -207,3 +209,23 @@ fn random_expansion_tests(mut rng: R, type_name: String) { } end_timer!(start); } + +pub fn random_serialization_test(type_name: String) { + let mut rng = XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, + 0xe5, + ]); + let message = format!("serialization {}", type_name); + let start = start_timer!(|| message); + for _ in 0..1000000 { + let a = F::random(&mut rng); + let bytes = a.to_raw_bytes(); + let b = F::from_raw_bytes(&bytes).unwrap(); + assert_eq!(a, b); + let mut buf = Vec::new(); + a.write_raw(&mut buf).unwrap(); + let b = F::read_raw(&mut &buf[..]).unwrap(); + assert_eq!(a, b); + } + end_timer!(start); +}