Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stake-pool: Remove unsafe pointer casts via Pod types #5185

Merged
merged 3 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions stake-pool/cli/src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use {
solana_cli_output::{QuietDisplay, VerboseDisplay},
solana_sdk::native_token::Sol,
solana_sdk::{pubkey::Pubkey, stake::state::Lockup},
spl_stake_pool::state::{Fee, StakePool, StakeStatus, ValidatorList, ValidatorStakeInfo},
spl_stake_pool::state::{
Fee, PodStakeStatus, StakePool, StakeStatus, ValidatorList, ValidatorStakeInfo,
},
std::fmt::{Display, Formatter, Result, Write},
};

Expand Down Expand Up @@ -384,8 +386,9 @@ impl From<ValidatorStakeInfo> for CliStakePoolValidator {
}
}

impl From<StakeStatus> for CliStakePoolValidatorStakeStatus {
fn from(s: StakeStatus) -> CliStakePoolValidatorStakeStatus {
impl From<PodStakeStatus> for CliStakePoolValidatorStakeStatus {
fn from(s: PodStakeStatus) -> CliStakePoolValidatorStakeStatus {
let s = StakeStatus::try_from(s).unwrap();
match s {
StakeStatus::Active => CliStakePoolValidatorStakeStatus::Active,
StakeStatus::DeactivatingTransient => {
Expand Down
1 change: 1 addition & 0 deletions stake-pool/program/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ test-sbf = []
[dependencies]
arrayref = "0.3.7"
borsh = "0.10"
bytemuck = "1.13"
num-derive = "0.4"
num-traits = "0.2"
num_enum = "0.7.0"
Expand Down
175 changes: 52 additions & 123 deletions stake-pool/program/src/big_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
use {
arrayref::array_ref,
borsh::BorshDeserialize,
solana_program::{
program_error::ProgramError, program_memory::sol_memmove, program_pack::Pack,
},
std::marker::PhantomData,
bytemuck::Pod,
solana_program::{program_error::ProgramError, program_memory::sol_memmove},
std::mem,
};

/// Contains easy to use utilities for a big vector of Borsh-compatible types,
Expand All @@ -32,7 +31,7 @@ impl<'data> BigVec<'data> {
}

/// Retain all elements that match the provided function, discard all others
pub fn retain<T: Pack, F: Fn(&[u8]) -> bool>(
pub fn retain<T: Pod, F: Fn(&[u8]) -> bool>(
&mut self,
predicate: F,
) -> Result<(), ProgramError> {
Expand All @@ -42,12 +41,12 @@ impl<'data> BigVec<'data> {

let data_start_index = VEC_SIZE_BYTES;
let data_end_index =
data_start_index.saturating_add((vec_len as usize).saturating_mul(T::LEN));
for start_index in (data_start_index..data_end_index).step_by(T::LEN) {
let end_index = start_index + T::LEN;
data_start_index.saturating_add((vec_len as usize).saturating_mul(mem::size_of::<T>()));
for start_index in (data_start_index..data_end_index).step_by(mem::size_of::<T>()) {
let end_index = start_index + mem::size_of::<T>();
let slice = &self.data[start_index..end_index];
if !predicate(slice) {
let gap = removals_found * T::LEN;
let gap = removals_found * mem::size_of::<T>();
if removals_found > 0 {
// In case the compute budget is ever bumped up, allowing us
// to use this safe code instead:
Expand All @@ -68,7 +67,7 @@ impl<'data> BigVec<'data> {

// final memmove
if removals_found > 0 {
let gap = removals_found * T::LEN;
let gap = removals_found * mem::size_of::<T>();
// In case the compute budget is ever bumped up, allowing us
// to use this safe code instead:
//self.data.copy_within(dst_start_index + gap..data_end_index, dst_start_index);
Expand All @@ -88,11 +87,11 @@ impl<'data> BigVec<'data> {
}

/// Extracts a slice of the data types
pub fn deserialize_mut_slice<T: Pack>(
pub fn deserialize_mut_slice<T: Pod>(
&mut self,
skip: usize,
len: usize,
) -> Result<Vec<&'data mut T>, ProgramError> {
) -> Result<&mut [T], ProgramError> {
let vec_len = self.len();
let last_item_index = skip
.checked_add(len)
Expand All @@ -101,66 +100,60 @@ impl<'data> BigVec<'data> {
return Err(ProgramError::AccountDataTooSmall);
}

let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(T::LEN));
let end_index = start_index.saturating_add(len.saturating_mul(T::LEN));
let mut deserialized = vec![];
for slice in self.data[start_index..end_index].chunks_exact_mut(T::LEN) {
deserialized.push(unsafe { &mut *(slice.as_ptr() as *mut T) });
let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
bytemuck::try_cast_slice_mut(&mut self.data[start_index..end_index])
.map_err(|_| ProgramError::InvalidAccountData)
}

/// Extracts a slice of the data types
pub fn deserialize_slice<T: Pod>(&self, skip: usize, len: usize) -> Result<&[T], ProgramError> {
let vec_len = self.len();
let last_item_index = skip
.checked_add(len)
.ok_or(ProgramError::AccountDataTooSmall)?;
if last_item_index > vec_len as usize {
return Err(ProgramError::AccountDataTooSmall);
}
Ok(deserialized)

let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
bytemuck::try_cast_slice(&self.data[start_index..end_index])
.map_err(|_| ProgramError::InvalidAccountData)
}

/// Add new element to the end
pub fn push<T: Pack>(&mut self, element: T) -> Result<(), ProgramError> {
pub fn push<T: Pod>(&mut self, element: T) -> Result<(), ProgramError> {
let mut vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
let mut vec_len = u32::try_from_slice(vec_len_ref)?;

let start_index = VEC_SIZE_BYTES + vec_len as usize * T::LEN;
let end_index = start_index + T::LEN;
let start_index = VEC_SIZE_BYTES + vec_len as usize * mem::size_of::<T>();
let end_index = start_index + mem::size_of::<T>();

vec_len += 1;
borsh::to_writer(&mut vec_len_ref, &vec_len)?;

if self.data.len() < end_index {
return Err(ProgramError::AccountDataTooSmall);
}
let element_ref = &mut self.data[start_index..start_index + T::LEN];
element.pack_into_slice(element_ref);
let element_ref = bytemuck::try_from_bytes_mut(
&mut self.data[start_index..start_index + mem::size_of::<T>()],
)
.map_err(|_| ProgramError::InvalidAccountData)?;
*element_ref = element;
Ok(())
}

/// Get an iterator for the type provided
pub fn iter<'vec, T: Pack>(&'vec self) -> Iter<'data, 'vec, T> {
Iter {
len: self.len() as usize,
current: 0,
current_index: VEC_SIZE_BYTES,
inner: self,
phantom: PhantomData,
}
}

/// Get a mutable iterator for the type provided
pub fn iter_mut<'vec, T: Pack>(&'vec mut self) -> IterMut<'data, 'vec, T> {
IterMut {
len: self.len() as usize,
current: 0,
current_index: VEC_SIZE_BYTES,
inner: self,
phantom: PhantomData,
}
}

/// Find matching data in the array
pub fn find<T: Pack, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
pub fn find<T: Pod, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let end_index = current_index + mem::size_of::<T>();
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice) {
return Some(unsafe { &*(current_slice.as_ptr() as *const T) });
return Some(bytemuck::from_bytes(current_slice));
}
current_index = end_index;
current += 1;
Expand All @@ -169,15 +162,17 @@ impl<'data> BigVec<'data> {
}

/// Find matching data in the array
pub fn find_mut<T: Pack, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
pub fn find_mut<T: Pod, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let end_index = current_index + mem::size_of::<T>();
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice) {
return Some(unsafe { &mut *(current_slice.as_ptr() as *mut T) });
return Some(bytemuck::from_bytes_mut(
&mut self.data[current_index..end_index],
));
}
current_index = end_index;
current += 1;
Expand All @@ -186,84 +181,16 @@ impl<'data> BigVec<'data> {
}
}

/// Iterator wrapper over a BigVec
pub struct Iter<'data, 'vec, T> {
len: usize,
current: usize,
current_index: usize,
inner: &'vec BigVec<'data>,
phantom: PhantomData<T>,
}

impl<'data, 'vec, T: Pack + 'data> Iterator for Iter<'data, 'vec, T> {
type Item = &'data T;

fn next(&mut self) -> Option<Self::Item> {
if self.current == self.len {
None
} else {
let end_index = self.current_index + T::LEN;
let value = Some(unsafe {
&*(self.inner.data[self.current_index..end_index].as_ptr() as *const T)
});
self.current += 1;
self.current_index = end_index;
value
}
}
}

/// Iterator wrapper over a BigVec
pub struct IterMut<'data, 'vec, T> {
len: usize,
current: usize,
current_index: usize,
inner: &'vec mut BigVec<'data>,
phantom: PhantomData<T>,
}

impl<'data, 'vec, T: Pack + 'data> Iterator for IterMut<'data, 'vec, T> {
type Item = &'data mut T;

fn next(&mut self) -> Option<Self::Item> {
if self.current == self.len {
None
} else {
let end_index = self.current_index + T::LEN;
let value = Some(unsafe {
&mut *(self.inner.data[self.current_index..end_index].as_ptr() as *mut T)
});
self.current += 1;
self.current_index = end_index;
value
}
}
}

#[cfg(test)]
mod tests {
use {super::*, solana_program::program_pack::Sealed};
use {super::*, bytemuck::Zeroable};

#[derive(Debug, PartialEq)]
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Pod, Zeroable)]
struct TestStruct {
value: [u8; 8],
}

impl Sealed for TestStruct {}

impl Pack for TestStruct {
const LEN: usize = 8;
fn pack_into_slice(&self, data: &mut [u8]) {
let mut data = data;
borsh::to_writer(&mut data, &self.value).unwrap();
}
fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
Ok(TestStruct {
value: src.try_into().unwrap(),
})
}
}

impl TestStruct {
fn new(value: u8) -> Self {
let value = [value, 0, 0, 0, 0, 0, 0, 0];
Expand All @@ -281,7 +208,9 @@ mod tests {

fn check_big_vec_eq(big_vec: &BigVec, slice: &[u8]) {
assert!(big_vec
.iter::<TestStruct>()
.deserialize_slice::<TestStruct>(0, big_vec.len() as usize)
.unwrap()
.iter()
.map(|x| &x.value[0])
.zip(slice.iter())
.all(|(a, b)| a == b));
Expand Down
Loading
Loading