diff --git a/derive/src/field/mod.rs b/derive/src/field/mod.rs index cfcbd798..37c1da17 100644 --- a/derive/src/field/mod.rs +++ b/derive/src/field/mod.rs @@ -94,6 +94,10 @@ pub(crate) fn impl_field(input: TokenStream) -> TokenStream { let modulus_limbs = crate::utils::big_to_limbs(&modulus, num_limbs); let modulus_str = format!("0x{}", modulus.to_str_radix(16)); let modulus_limbs_ident = quote! {[#(#modulus_limbs,)*]}; + + let modulus_limbs_32 = crate::utils::big_to_limbs_32(&modulus, num_limbs * 2); + let modulus_limbs_32_ident = quote! {[#(#modulus_limbs_32,)*]}; + let to_token = |e: &BigUint| big_to_token(e, num_limbs); // binary modulus @@ -265,6 +269,7 @@ pub(crate) fn impl_field(input: TokenStream) -> TokenStream { pub const SIZE: usize = #num_limbs * 8; pub const NUM_LIMBS: usize = #num_limbs; pub(crate) const MODULUS_LIMBS: [u64; Self::NUM_LIMBS] = #modulus_limbs_ident; + pub(crate) const MODULUS_LIMBS_32: [u32; Self::NUM_LIMBS*2] = #modulus_limbs_32_ident; const R: Self = Self(#r1); const R2: Self = Self(#r2); const R3: Self = Self(#r3); diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 91e4ef05..9b5cd3f8 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -18,6 +18,13 @@ pub(crate) fn big_to_limbs(e: &BigUint, number_of_limbs: usize) -> Vec { .collect() } +pub(crate) fn big_to_limbs_32(e: &BigUint, number_of_limbs: usize) -> Vec { + decompose(e, number_of_limbs, 32) + .iter() + .map(|x| x.to_u32().unwrap()) + .collect() +} + pub(crate) fn big_to_token(e: &BigUint, number_of_limbs: usize) -> proc_macro2::TokenStream { let limbs = big_to_limbs(e, number_of_limbs); quote::quote! {[#(#limbs,)*]} diff --git a/src/derive/field/common.rs b/src/derive/field/common.rs index cb78598e..ff3b7772 100644 --- a/src/derive/field/common.rs +++ b/src/derive/field/common.rs @@ -20,7 +20,16 @@ macro_rules! field_bits { let limbs = (0..Self::NUM_LIMBS * 8 / STEP) .map(|off| { - u64::from_le_bytes(bytes[off * STEP..(off + 1) * STEP].try_into().unwrap()) + #[cfg(target_pointer_width = "64")] + let limb = u64::from_le_bytes( + bytes[off * STEP..(off + 1) * STEP].try_into().unwrap(), + ); + #[cfg(not(target_pointer_width = "64"))] + let limb = u32::from_le_bytes( + bytes[off * STEP..(off + 1) * STEP].try_into().unwrap(), + ); + + limb }) .collect::>(); @@ -31,7 +40,7 @@ macro_rules! field_bits { #[cfg(target_pointer_width = "64")] let bits = ff::FieldBits::new(Self::MODULUS_LIMBS); #[cfg(not(target_pointer_width = "64"))] - let bits = ff::FieldBits::new(MODULUS_LIMBS_32.0); + let bits = ff::FieldBits::new(Self::MODULUS_LIMBS_32); bits }