Skip to content

Commit

Permalink
Refactor & add testing
Browse files Browse the repository at this point in the history
  • Loading branch information
nyunyunyunyu committed May 11, 2023
1 parent f8cda76 commit a8b1804
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 67 deletions.
3 changes: 3 additions & 0 deletions halo2-base/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
//! Base library to build Halo2 circuits.
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
#![feature(const_cmp)]
#![feature(stmt_expr_attributes)]
#![feature(trait_alias)]
#![deny(clippy::perf)]
Expand Down
Empty file.
125 changes: 58 additions & 67 deletions halo2-base/src/safe_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,110 +8,101 @@ pub use crate::{
QuantumCell::{self, Constant, Existing, Witness},
};
use std::cmp::{max, min};
use std::sync::Arc;

#[cfg(test)]
pub mod tests;

type RawAssignedValues<F> = Vec<AssignedValue<F>>;

const BITS_PER_BYTE: usize = 8;
// Each AssignedValue can at most represent 8 bytes.
const MAX_BYTE_PER_ELEMENT: usize = 8;

// SafeType's goal is to avoid out-of-range undefined behavior.
// When building circuits, it's common to use mulitple AssignedValue<F> to represent
// a logical varaible. For example, we might want to represent a hash with 32 AssignedValue<F>
// where each AssignedValue represents 1 byte. However, the range of AssignedValue<F> is much
// larger than 1 byte(0~255). If a circuit takes 32 AssignedValue<F> as inputs and some of them
// are actually greater than 255, there could be some undefined behaviors.
// SafeType gurantees the value range of its owned AssignedValue<F>. So circuits don't need to
// do any extra value checking if they take SafeType as inputs.
/// SafeType's goal is to avoid out-of-range undefined behavior.
/// When building circuits, it's common to use mulitple AssignedValue<F> to represent
/// a logical varaible. For example, we might want to represent a hash with 32 AssignedValue<F>
/// where each AssignedValue represents 1 byte. However, the range of AssignedValue<F> is much
/// larger than 1 byte(0~255). If a circuit takes 32 AssignedValue<F> as inputs and some of them
/// are actually greater than 255, there could be some undefined behaviors.
/// SafeType gurantees the value range of its owned AssignedValue<F>. So circuits don't need to
/// do any extra value checking if they take SafeType as inputs.
#[derive(Clone, Debug)]
pub struct SafeType<F: ScalarField, const B: usize, const L: usize> {
value: RawAssignedValues<F>,
pub struct SafeType<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>
where [(); (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE)]: Sized {
// value is stored in little-endian. (BYTES_PER_ELE * BITS_PER_BYTE) is the number of bits of a single element.
value: [AssignedValue<F>; (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE)],
}

impl<F: ScalarField, const B: usize, const L: usize> SafeType<F, B, L> {
impl<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize> SafeType<F, BYTES_PER_ELE, TOTAL_BITS>
where [(); (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE)]: Sized {
pub const BYTES_PER_ELE: usize = BYTES_PER_ELE;
pub const TOTAL_BITS: usize = TOTAL_BITS;
pub const BITS_PER_ELE: usize = min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE);
pub const VALUE_LENGTH: usize = (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE);

// new is private so Safetype can only be constructed by this crate.
fn new(raw_values: RawAssignedValues<F>) -> Self {
Self { value: raw_values }
Self { value: raw_values.try_into().unwrap() }
}

// Return values in littile-endian.
pub fn value(&self) -> &RawAssignedValues<F> {
pub fn value(&self) -> &[AssignedValue<F>; (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE)] {
&self.value
}

// length of value() must equal to value_length().
pub fn value_length() -> usize {
L
}

// All elements in value() need to be in [0, element_limit()].
pub fn element_limit() -> u64 {
((1u128 << B) - 1) as u64
}

// Each element in value() has element_bits() bits.
pub fn element_bits() -> usize {
B
}
}

pub type SafeBool<F> = SafeType<F, 1, 1>;
pub type SafeUint8<F> = SafeType<F, 8, 1>;
pub type SafeUint16<F> = SafeType<F, 16, 1>;
pub type SafeUint32<F> = SafeType<F, 32, 1>;
pub type SafeUint64<F> = SafeType<F, 64, 1>;
pub type SafeUint128<F> = SafeType<F, 64, 2>;
pub type SafeUint256<F> = SafeType<F, 64, 4>;
pub type SafeBytes32<F> = SafeType<F, 8, 32>;
// (2^(F::NUM_BITS) - 1) might not be a valid value for F. e.g. max value of F is a prime in [2^(F::NUM_BITS-1), 2^(F::NUM_BITS) - 1]
type CompactSafeType<F: ScalarField, const TOTAL_BITS: usize> = SafeType<F, { ((F::NUM_BITS - 1) / 8) as usize}, TOTAL_BITS>;

pub type SafeBool<F> = CompactSafeType<F, 1>;
pub type SafeUint8<F> = CompactSafeType<F, 8>;
pub type SafeUint16<F> = CompactSafeType<F, 16>;
pub type SafeUint32<F> = CompactSafeType<F, 32>;
pub type SafeUint64<F> = CompactSafeType<F, 64>;
pub type SafeUint128<F> = CompactSafeType<F, 128>;
pub type SafeUint256<F> = CompactSafeType<F, 256>;
pub type SafeBytes32<F> = SafeType<F, 1, 256>;

pub struct SafeTypeChip<F: ScalarField> {
pub range_chip: RangeChip<F>,
pub byte_bases: Vec<QuantumCell<F>>,
pub range_chip: Arc<RangeChip<F>>,
}

impl<F: ScalarField> SafeTypeChip<F> {
pub fn new(lookup_bits: usize) -> Self {
let byte_base = F::from(1u64 << BITS_PER_BYTE);
let mut running_base = F::one();
let num_bases = MAX_BYTE_PER_ELEMENT;
let mut byte_bases = Vec::with_capacity(num_bases);
for _ in 0..num_bases {
byte_bases.push(Constant(running_base));
running_base *= &byte_base;
}

Self { range_chip: RangeChip::default(lookup_bits), byte_bases }
impl<F: ScalarField> SafeTypeChip< F> {
pub fn new(range_chip: Arc<RangeChip<F>>) -> Self {
Self { range_chip: Arc::clone(&range_chip) }
}

pub fn raw_bytes_to<const B: usize, const L: usize>(
pub fn raw_bytes_to<const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>(
&self,
ctx: &mut Context<F>,
inputs: RawAssignedValues<F>,
) -> SafeType<F, B, L> {
let value_length = SafeType::<F, B, L>::value_length();
let element_bits = SafeType::<F, B, L>::element_bits();
let bits = value_length * element_bits;
) -> SafeType<F, BYTES_PER_ELE, TOTAL_BITS>
where [(); (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE)]: Sized {
let element_bits = SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::BITS_PER_ELE;
let bits = TOTAL_BITS;
assert!(
inputs.len() * BITS_PER_BYTE == max(bits, BITS_PER_BYTE),
"number of bits doesn't match"
);
self.add_bytes_constraints(ctx, &inputs, bits);
if value_length == 1 || element_bits == BITS_PER_BYTE {
return SafeType::<F, B, L>::new(inputs);
// inputs is a bool or uint8.
if element_bits == BITS_PER_BYTE {
return SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(inputs);
};
let bytes_per_element = element_bits / BITS_PER_BYTE;

let mut value = vec![];
for i in 0..value_length {
let start = i * bytes_per_element;
let end = start + bytes_per_element;
let mut byte_base = vec![];
for i in 0..BYTES_PER_ELE {
byte_base.push(Witness(self.range_chip.gate.pow_of_two[i * BITS_PER_BYTE]));
}
for chunk in inputs.chunks(BYTES_PER_ELE) {
let acc = self.range_chip.gate.inner_product(
ctx,
inputs[start..end].to_vec(),
self.byte_bases[..bytes_per_element].to_vec(),
chunk.to_vec(),
byte_base[..chunk.len()].to_vec(),
);
value.push(acc);
}
SafeType::<F, B, L>::new(value)
SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(value)
}

fn add_bytes_constraints(
Expand All @@ -123,7 +114,7 @@ impl<F: ScalarField> SafeTypeChip<F> {
let mut bits_left = bits;
for input in inputs {
let num_bit = min(bits_left, BITS_PER_BYTE);
self.range_chip.check_less_than_safe(ctx, *input, 1u64 << num_bit);
self.range_chip.range_check(ctx, *input,num_bit);
bits_left -= num_bit;
}
}
Expand Down
Loading

0 comments on commit a8b1804

Please sign in to comment.