Skip to content

Commit

Permalink
Merge pull request #395 from unum-cloud/main-dev
Browse files Browse the repository at this point in the history
Improve Quantization in Rust
  • Loading branch information
ashvardanian committed Apr 12, 2024
2 parents 42dd660 + 63116c0 commit 0af3a3c
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 24 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@
"Println",
"pytest",
"Quickstart",
"repr",
"rtype",
"SIMD",
"simsimd",
Expand Down
97 changes: 73 additions & 24 deletions rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ pub type Key = u64;
pub type Distance = f32;

/// Callback signature for custom metric functions, defined in the Rust layer and used in the C++ layer.
pub type StatefullMetric = unsafe extern "C" fn(
pub type StatefulMetric = unsafe extern "C" fn(
*const std::ffi::c_void,
*const std::ffi::c_void,
*mut std::ffi::c_void,
) -> Distance;

/// Callback signature for custom predicate functions, defined in the Rust layer and used in the C++ layer.
pub type StatefullPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool;
pub type StatefulPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool;

/// Represents errors that can occur when addressing bits.
#[derive(Debug)]
Expand Down Expand Up @@ -73,31 +73,76 @@ pub trait BitAddressable {
fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError>;
}

/// Eight bits packed into one byte, addressable bit by bit.
///
/// The wrapper is `#[repr(transparent)]` over its single public `u8` field, so a
/// `b1x8` has exactly the layout of the byte it wraps. Equality compares the
/// wrapped bytes.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct b1x8(pub u8);

impl b1x8 {
    /// Views a borrowed byte slice as a slice of `b1x8`, one element per byte,
    /// so that bit-level operations can be applied to existing byte data.
    ///
    /// No data is copied; the returned slice borrows the same memory.
    pub fn from_u8s(slice: &[u8]) -> &[Self] {
        let len = slice.len();
        let data = slice.as_ptr().cast::<Self>();
        // SAFETY: `b1x8` is `#[repr(transparent)]` over `u8`, so both element types
        // share size and alignment; pointer and length come from a live slice.
        unsafe { std::slice::from_raw_parts(data, len) }
    }

    /// Views a mutably borrowed byte slice as a mutable slice of `b1x8`,
    /// enabling in-place bit-level edits of the underlying bytes.
    pub fn from_mut_u8s(slice: &mut [u8]) -> &mut [Self] {
        let len = slice.len();
        let data = slice.as_mut_ptr().cast::<Self>();
        // SAFETY: same layout argument as `from_u8s`; the exclusive borrow of
        // `slice` is carried over to the returned slice.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }

    /// Views a slice of `b1x8` as the plain bytes it wraps, for use in
    /// byte-oriented code after bit-level manipulation.
    pub fn to_u8s(slice: &[Self]) -> &[u8] {
        let len = slice.len();
        let data = slice.as_ptr().cast::<u8>();
        // SAFETY: every `b1x8` is exactly one `u8` (`#[repr(transparent)]`).
        unsafe { std::slice::from_raw_parts(data, len) }
    }

    /// Views a mutable slice of `b1x8` as the mutable bytes it wraps, so the
    /// original byte data can be modified further.
    pub fn to_mut_u8s(slice: &mut [Self]) -> &mut [u8] {
        let len = slice.len();
        let data = slice.as_mut_ptr().cast::<u8>();
        // SAFETY: every `b1x8` is exactly one `u8`; exclusivity is inherited
        // from the incoming mutable borrow.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }
}

/// Half-precision (16-bit) floating-point value in the IEEE 754 layout.
///
/// The raw bit pattern is kept in a private `i16`: 1 sign bit, 5 exponent bits,
/// and 10 mantissa bits. The wrapper is `#[repr(transparent)]`, so it has the
/// same layout as the `i16` it stores.
#[allow(non_camel_case_types)]
#[repr(transparent)]
#[derive(Copy, Clone)]
pub struct f16(i16);

impl f16 {
    /// Views a borrowed slice of `i16` bit patterns as a slice of `f16`, so
    /// half-precision data stored in ordinary 16-bit integer arrays can be used
    /// directly. No data is copied.
    pub fn from_i16s(slice: &[i16]) -> &[Self] {
        let len = slice.len();
        let data = slice.as_ptr().cast::<Self>();
        // SAFETY: `f16` is `#[repr(transparent)]` over `i16`, so both element types
        // share size and alignment; pointer and length come from a live slice.
        unsafe { std::slice::from_raw_parts(data, len) }
    }

    /// Views a mutably borrowed slice of `i16` bit patterns as a mutable slice
    /// of `f16`, enabling in-place updates of half-precision data.
    pub fn from_mut_i16s(slice: &mut [i16]) -> &mut [Self] {
        let len = slice.len();
        let data = slice.as_mut_ptr().cast::<Self>();
        // SAFETY: same layout argument as `from_i16s`; the exclusive borrow of
        // `slice` is carried over to the returned slice.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }

    /// Views a slice of `f16` as the raw `i16` bit patterns it wraps, for
    /// storage or processing in code that expects standard integer types.
    pub fn to_i16s(slice: &[Self]) -> &[i16] {
        let len = slice.len();
        let data = slice.as_ptr().cast::<i16>();
        // SAFETY: every `f16` is exactly one `i16` (`#[repr(transparent)]`).
        unsafe { std::slice::from_raw_parts(data, len) }
    }

    /// Views a mutable slice of `f16` as the mutable raw `i16` bit patterns it
    /// wraps, so the original integer data can be modified further.
    pub fn to_mut_i16s(slice: &mut [Self]) -> &mut [i16] {
        let len = slice.len();
        let data = slice.as_mut_ptr().cast::<i16>();
        // SAFETY: every `f16` is exactly one `i16`; exclusivity is inherited
        // from the incoming mutable borrow.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }
}

impl BitAddressable for b1x8 {
/// Sets a bit at a specific index within the byte.
///
Expand Down Expand Up @@ -167,28 +212,32 @@ impl BitAddressable for [b1x8] {
}
}

#[repr(transparent)]
#[allow(non_camel_case_types)]
pub struct f16(i16);
impl f16 {
/// Casts a slice of `i16` to a slice of `f16`.
pub fn from_i16s(slice: &[i16]) -> &[Self] {
unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const Self, slice.len()) }
}
impl PartialEq for f16 {
fn eq(&self, other: &Self) -> bool {
// Check for NaN values first (exponent all ones and non-zero mantissa)
let nan_self = (self.0 & 0x7C00) == 0x7C00 && (self.0 & 0x03FF) != 0;
let nan_other = (other.0 & 0x7C00) == 0x7C00 && (other.0 & 0x03FF) != 0;
if nan_self || nan_other {
return false;
}

/// Casts a mutable slice of `i16` to a mutable slice of `f16`.
pub fn from_mut_i16s(slice: &mut [i16]) -> &mut [Self] {
unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut Self, slice.len()) }
self.0 == other.0
}
}

/// Casts a slice of `f16` back to a slice of `i16`.
pub fn to_i16s(slice: &[Self]) -> &[i16] {
unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const i16, slice.len()) }
/// Debug output is the wrapped byte as exactly eight binary digits,
/// zero-padded on the left (for example `00000101` for the value 5).
impl std::fmt::Debug for b1x8 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte = self.0;
        f.write_fmt(format_args!("{:08b}", byte))
    }
}

/// Casts a mutable slice of `f16` back to a mutable slice of `i16`.
pub fn to_mut_i16s(slice: &mut [Self]) -> &mut [i16] {
unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut i16, slice.len()) }
/// Debug output splits the raw bit pattern into its three fields, printed as
/// `sign|exponent|mantissa`: the sign as `0` or `1`, the exponent as five
/// binary digits, and the mantissa as ten binary digits.
impl std::fmt::Debug for f16 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const EXPONENT_MASK: u16 = 0x1F;
        const MANTISSA_MASK: u16 = 0x3FF;

        // Same-width reinterpretation of the stored bits as unsigned, so the
        // shifts below are logical rather than sign-extending.
        let raw = self.0 as u16;
        let sign_bit = raw >> 15;
        let exponent_bits = (raw >> 10) & EXPONENT_MASK;
        let mantissa_bits = raw & MANTISSA_MASK;

        f.write_fmt(format_args!(
            "{}|{:05b}|{:010b}",
            sign_bit, exponent_bits, mantissa_bits
        ))
    }
}

Expand Down Expand Up @@ -273,7 +322,7 @@ pub mod ffi {
pub fn change_metric_kind(self: &NativeIndex, metric: MetricKind);

/// Changes the metric function used to calculate the distance between vectors.
/// Avoids the `std::ffi::c_void` type and the `StatefullMetric` type, that the FFI
/// Avoids the `std::ffi::c_void` type and the `StatefulMetric` type, that the FFI
/// does not support, replacing them with basic pointer-sized integer types.
/// The first two arguments are the pointers to the vectors to compare, and the third
/// argument is the `metric_state` propagated from the Rust layer.
Expand Down

0 comments on commit 0af3a3c

Please sign in to comment.