Skip to content

Commit

Permalink
only implement Message for iterators
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
  • Loading branch information
andrewwhitehead committed Aug 8, 2022
1 parent 5ca2178 commit 3102b60
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 65 deletions.
70 changes: 10 additions & 60 deletions src/hash_to_curve/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,71 +102,21 @@ pub trait Message {
///
/// The parameters to successive calls to `f` are treated as a
/// single concatenated octet string.
fn consume(self, f: impl FnMut(&[u8]));
fn input_message(self, f: impl FnMut(&[u8]));
}

impl Message for &[u8] {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self)
}
}

impl<const N: usize> Message for &[u8; N] {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self)
}
}

impl Message for &str {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_bytes())
}
}

impl Message for &[&[u8]] {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
impl<M, I> Message for I
where
M: AsRef<[u8]>,
I: IntoIterator<Item = M>,
{
fn input_message(self, mut f: impl FnMut(&[u8])) {
for msg in self {
f(msg);
f(msg.as_ref())
}
}
}

#[cfg(feature = "alloc")]
impl Message for Vec<u8> {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_slice())
}
}

#[cfg(feature = "alloc")]
impl Message for &Vec<u8> {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_slice())
}
}

#[cfg(feature = "alloc")]
impl Message for alloc::string::String {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_bytes())
}
}

#[cfg(feature = "alloc")]
impl Message for &alloc::string::String {
#[inline]
fn consume(self, mut f: impl FnMut(&[u8])) {
f(self.as_bytes())
}
}

/// A trait for message expansion methods supported by hash-to-curve.
pub trait ExpandMessage {
/// Initializes a message expander.
Expand Down Expand Up @@ -230,7 +180,7 @@ where

let dst = ExpandMsgDst::for_xof::<H, L>(dst);
let mut hash = H::default();
message.consume(|m| hash.update(m));
message.input_message(|m| hash.update(m));
let reader = hash
.chain((len_in_bytes as u16).to_be_bytes())
.chain(dst.data())
Expand Down Expand Up @@ -294,7 +244,7 @@ where
let dst = ExpandMsgDst::for_xmd::<H>(dst);
let mut hash_b_0 =
H::default().chain(GenericArray::<u8, <H as BlockInput>::BlockSize>::default());
message.consume(|m| hash_b_0.update(m));
message.input_message(|m| hash_b_0.update(m));
let b_0 = hash_b_0
.chain((len_in_bytes as u16).to_be_bytes())
.chain([0u8])
Expand Down
22 changes: 21 additions & 1 deletion tests/expand_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,26 @@ mod tests {
use sha2::{Sha256, Sha512};
use sha3::{Shake128, Shake256};

#[test]
fn test_expand_message_parts() {
const EXPAND_LEN: usize = 16;
let mut b1 = [0u8; EXPAND_LEN];
let mut b2 = [0u8; EXPAND_LEN];
<ExpandMsgXmd<Sha256> as ExpandMessage>::init_expand::<_, U32>(
[b"sig" as &[u8], b"nature"],
&[],
EXPAND_LEN,
)
.read_into(&mut b1);
<ExpandMsgXmd<Sha256> as ExpandMessage>::init_expand::<_, U32>(
[b"signature"],
&[],
EXPAND_LEN,
)
.read_into(&mut b2);
assert_eq!(b1, b2);
}

struct TestCase {
msg: &'static [u8],
dst: &'static [u8],
Expand All @@ -19,7 +39,7 @@ mod tests {
pub fn run<E: ExpandMessage>(self) {
let mut buf = [0u8; 128];
let output = &mut buf[..self.len_in_bytes];
E::init_expand::<_, U32>(self.msg, self.dst, self.len_in_bytes).read_into(output);
E::init_expand::<_, U32>([self.msg], self.dst, self.len_in_bytes).read_into(output);
if output != self.uniform_bytes {
panic!(
"Failed: expand_message.\n\
Expand Down
6 changes: 4 additions & 2 deletions tests/hash_to_curve_g1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ mod tests {

for case in cases {
let g = <G1Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::hash_to_curve(
case.msg, case.dst,
[case.msg],
case.dst,
);
let aff = G1Affine::from(g);
let g_uncompressed = aff.to_uncompressed();
Expand Down Expand Up @@ -179,7 +180,8 @@ mod tests {

for case in cases {
let g = <G1Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::encode_to_curve(
case.msg, case.dst,
[case.msg],
case.dst,
);
let aff = G1Affine::from(g);
let g_uncompressed = aff.to_uncompressed();
Expand Down
6 changes: 4 additions & 2 deletions tests/hash_to_curve_g2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ mod tests {

for case in cases {
let g = <G2Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::hash_to_curve(
case.msg, case.dst,
[case.msg],
case.dst,
);
let aff = G2Affine::from(g);
let g_uncompressed = aff.to_uncompressed();
Expand Down Expand Up @@ -219,7 +220,8 @@ mod tests {

for case in cases {
let g = <G2Projective as HashToCurve<ExpandMsgXmd<Sha256>>>::encode_to_curve(
case.msg, case.dst,
[case.msg],
case.dst,
);
let aff = G2Affine::from(g);
let g_uncompressed = aff.to_uncompressed();
Expand Down

0 comments on commit 3102b60

Please sign in to comment.