From 3102b6047b7e5771740fbcb81ce28e2f4483c08d Mon Sep 17 00:00:00 2001 From: Andrew Whitehead Date: Mon, 8 Aug 2022 16:23:32 -0700 Subject: [PATCH] only implement Message for iterators Signed-off-by: Andrew Whitehead --- src/hash_to_curve/expand_msg.rs | 70 +++++---------------------------- tests/expand_msg.rs | 22 ++++++++++- tests/hash_to_curve_g1.rs | 6 ++- tests/hash_to_curve_g2.rs | 6 ++- 4 files changed, 39 insertions(+), 65 deletions(-) diff --git a/src/hash_to_curve/expand_msg.rs b/src/hash_to_curve/expand_msg.rs index 4e816fd5..33a3d9f9 100644 --- a/src/hash_to_curve/expand_msg.rs +++ b/src/hash_to_curve/expand_msg.rs @@ -102,71 +102,21 @@ pub trait Message { /// /// The parameters to successive calls to `f` are treated as a /// single concatenated octet string. - fn consume(self, f: impl FnMut(&[u8])); + fn input_message(self, f: impl FnMut(&[u8])); } -impl Message for &[u8] { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self) - } -} - -impl Message for &[u8; N] { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self) - } -} - -impl Message for &str { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_bytes()) - } -} - -impl Message for &[&[u8]] { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { +impl Message for I +where + M: AsRef<[u8]>, + I: IntoIterator, +{ + fn input_message(self, mut f: impl FnMut(&[u8])) { for msg in self { - f(msg); + f(msg.as_ref()) } } } -#[cfg(feature = "alloc")] -impl Message for Vec { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_slice()) - } -} - -#[cfg(feature = "alloc")] -impl Message for &Vec { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_slice()) - } -} - -#[cfg(feature = "alloc")] -impl Message for alloc::string::String { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_bytes()) - } -} - -#[cfg(feature = "alloc")] -impl Message for &alloc::string::String { - #[inline] - fn consume(self, mut f: impl FnMut(&[u8])) { - f(self.as_bytes()) - } -} - /// A trait for message expansion methods supported by hash-to-curve. pub trait ExpandMessage { /// Initializes a message expander. @@ -230,7 +180,7 @@ where let dst = ExpandMsgDst::for_xof::(dst); let mut hash = H::default(); - message.consume(|m| hash.update(m)); + message.input_message(|m| hash.update(m)); let reader = hash .chain((len_in_bytes as u16).to_be_bytes()) .chain(dst.data()) @@ -294,7 +244,7 @@ where let dst = ExpandMsgDst::for_xmd::(dst); let mut hash_b_0 = H::default().chain(GenericArray::::BlockSize>::default()); - message.consume(|m| hash_b_0.update(m)); + message.input_message(|m| hash_b_0.update(m)); let b_0 = hash_b_0 .chain((len_in_bytes as u16).to_be_bytes()) .chain([0u8]) diff --git a/tests/expand_msg.rs b/tests/expand_msg.rs index 0154a2e2..28e9432a 100644 --- a/tests/expand_msg.rs +++ b/tests/expand_msg.rs @@ -7,6 +7,26 @@ mod tests { use sha2::{Sha256, Sha512}; use sha3::{Shake128, Shake256}; + #[test] + fn test_expand_message_parts() { + const EXPAND_LEN: usize = 16; + let mut b1 = [0u8; EXPAND_LEN]; + let mut b2 = [0u8; EXPAND_LEN]; + as ExpandMessage>::init_expand::<_, U32>( + [b"sig" as &[u8], b"nature"], + &[], + EXPAND_LEN, + ) + .read_into(&mut b1); + as ExpandMessage>::init_expand::<_, U32>( + [b"signature"], + &[], + EXPAND_LEN, + ) + .read_into(&mut b2); + assert_eq!(b1, b2); + } + struct TestCase { msg: &'static [u8], dst: &'static [u8], @@ -19,7 +39,7 @@ mod tests { pub fn run(self) { let mut buf = [0u8; 128]; let output = &mut buf[..self.len_in_bytes]; - E::init_expand::<_, U32>(self.msg, self.dst, self.len_in_bytes).read_into(output); + E::init_expand::<_, U32>([self.msg], self.dst, self.len_in_bytes).read_into(output); if output != self.uniform_bytes { panic!( "Failed: expand_message.\n\ diff --git a/tests/hash_to_curve_g1.rs b/tests/hash_to_curve_g1.rs index be9db8e9..56d4bbc9 100644 --- a/tests/hash_to_curve_g1.rs +++ b/tests/hash_to_curve_g1.rs @@ -100,7 +100,8 @@ mod tests { for case in cases { let g = >>::hash_to_curve( - case.msg, case.dst, + [case.msg], + case.dst, ); let aff = G1Affine::from(g); let g_uncompressed = aff.to_uncompressed(); @@ -179,7 +180,8 @@ mod tests { for case in cases { let g = >>::encode_to_curve( - case.msg, case.dst, + [case.msg], + case.dst, ); let aff = G1Affine::from(g); let g_uncompressed = aff.to_uncompressed(); diff --git a/tests/hash_to_curve_g2.rs b/tests/hash_to_curve_g2.rs index c5228caa..4a516d6d 100644 --- a/tests/hash_to_curve_g2.rs +++ b/tests/hash_to_curve_g2.rs @@ -120,7 +120,8 @@ mod tests { for case in cases { let g = >>::hash_to_curve( - case.msg, case.dst, + [case.msg], + case.dst, ); let aff = G2Affine::from(g); let g_uncompressed = aff.to_uncompressed(); @@ -219,7 +220,8 @@ mod tests { for case in cases { let g = >>::encode_to_curve( - case.msg, case.dst, + [case.msg], + case.dst, ); let aff = G2Affine::from(g); let g_uncompressed = aff.to_uncompressed();