Skip to content

Commit

Permalink
Allow casting between slices of ZSTs and slices of non-ZSTs in all ca…
Browse files Browse the repository at this point in the history
…ses. (#256)

Casting ZST to non-ZST will result in a slice length of 0.
Casting non-ZST to ZST will only work if the input slice has length 0, and results in a slice length of 0; if the input slice is not of length 0, PodCastError::OutputSliceWouldHaveSlop is returned.

Updates the docs of the PodCastError variants to reflect when they can occur.
Updates the docs of try_cast_slice (and checked::) to remove note about ZST <-> non-ZST not being allowed.
Update bytes_of(_mut) to remove ZST check, since casting [ZST] -> [u8] is now allowed directly using cast_slice(_mut).
Update must_cast_slice checks and doctests to allow [ZST] -> [non-ZST], but disallow [non-ZST] -> [ZST].
  • Loading branch information
zachs18 committed Jul 30, 2024
1 parent 758774d commit 291a924
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 56 deletions.
21 changes: 13 additions & 8 deletions src/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ pub fn try_cast_slice_box<A: NoUninit, B: AnyBitPattern>(
{
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
Err((PodCastError::OutputSliceWouldHaveSlop, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Box
Expand Down Expand Up @@ -239,7 +239,7 @@ pub fn try_cast_vec<A: NoUninit, B: AnyBitPattern>(
// length and capacity are valid under B, as we do not want to
// change which bytes are considered part of the initialized slice
// of the Vec
Err((PodCastError::SizeMismatch, input))
Err((PodCastError::OutputSliceWouldHaveSlop, input))
} else {
// Because the size is an exact multiple, we can now change the length and
// capacity and recreate the Vec
Expand Down Expand Up @@ -431,7 +431,7 @@ pub fn try_cast_slice_rc<
{
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
Err((PodCastError::OutputSliceWouldHaveSlop, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Rc
Expand Down Expand Up @@ -499,7 +499,7 @@ pub fn try_cast_slice_arc<
{
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
Err((PodCastError::OutputSliceWouldHaveSlop, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Arc
Expand Down Expand Up @@ -850,13 +850,18 @@ impl<T: AnyBitPattern> sealed::FromBoxBytes for [T] {
let single_layout = Layout::new::<T>();
if bytes.layout.align() != single_layout.align() {
Err((PodCastError::AlignmentMismatch, bytes))
} else if single_layout.size() == 0 {
Err((PodCastError::SizeMismatch, bytes))
} else if bytes.layout.size() % single_layout.size() != 0 {
} else if (single_layout.size() == 0 && bytes.layout.size() != 0)
|| (single_layout.size() != 0
&& bytes.layout.size() % single_layout.size() != 0)
{
Err((PodCastError::OutputSliceWouldHaveSlop, bytes))
} else {
let (ptr, layout) = bytes.into_raw_parts();
let length = layout.size() / single_layout.size();
let length = if single_layout.size() != 0 {
layout.size() / single_layout.size()
} else {
0
};
let ptr =
core::ptr::slice_from_raw_parts_mut(ptr.as_ptr() as *mut T, length);
// SAFETY: See BoxBytes type invariant.
Expand Down
2 changes: 0 additions & 2 deletions src/checked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,6 @@ pub fn try_cast_mut<
/// type, and the output slice wouldn't be a whole number of elements when
/// accounting for the size change (eg: 3 `u16` values is 1.5 `u32` values, so
/// that's a failure).
/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts)
/// and a non-ZST.
/// * If any element of the converted slice would contain an invalid bit pattern
/// for `B` this fails.
#[inline]
Expand Down
42 changes: 18 additions & 24 deletions src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,9 @@ pub(crate) fn something_went_wrong<D>(_src: &str, _err: D) -> ! {
/// empty slice might not match the pointer value of the input reference.
#[inline(always)]
pub(crate) unsafe fn bytes_of<T: Copy>(t: &T) -> &[u8] {
if size_of::<T>() == 0 {
&[]
} else {
match try_cast_slice::<T, u8>(core::slice::from_ref(t)) {
Ok(s) => s,
Err(_) => unreachable!(),
}
match try_cast_slice::<T, u8>(core::slice::from_ref(t)) {
Ok(s) => s,
Err(_) => unreachable!(),
}
}

Expand All @@ -67,13 +63,9 @@ pub(crate) unsafe fn bytes_of<T: Copy>(t: &T) -> &[u8] {
/// empty slice might not match the pointer value of the input reference.
#[inline]
pub(crate) unsafe fn bytes_of_mut<T: Copy>(t: &mut T) -> &mut [u8] {
if size_of::<T>() == 0 {
&mut []
} else {
match try_cast_slice_mut::<T, u8>(core::slice::from_mut(t)) {
Ok(s) => s,
Err(_) => unreachable!(),
}
match try_cast_slice_mut::<T, u8>(core::slice::from_mut(t)) {
Ok(s) => s,
Err(_) => unreachable!(),
}
}

Expand Down Expand Up @@ -347,12 +339,11 @@ pub(crate) unsafe fn try_cast_mut<A: Copy, B: Copy>(
/// type, and the output slice wouldn't be a whole number of elements when
/// accounting for the size change (eg: 3 `u16` values is 1.5 `u32` values, so
/// that's a failure).
/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts)
/// and a non-ZST.
#[inline]
pub(crate) unsafe fn try_cast_slice<A: Copy, B: Copy>(
a: &[A],
) -> Result<&[B], PodCastError> {
let input_bytes = core::mem::size_of_val::<[A]>(a);
// Note(Lokathor): everything with `align_of` and `size_of` will optimize away
// after monomorphization.
if align_of::<B>() > align_of::<A>()
Expand All @@ -361,10 +352,11 @@ pub(crate) unsafe fn try_cast_slice<A: Copy, B: Copy>(
Err(PodCastError::TargetAlignmentGreaterAndInputNotAligned)
} else if size_of::<B>() == size_of::<A>() {
Ok(unsafe { core::slice::from_raw_parts(a.as_ptr() as *const B, a.len()) })
} else if size_of::<A>() == 0 || size_of::<B>() == 0 {
Err(PodCastError::SizeMismatch)
} else if core::mem::size_of_val(a) % size_of::<B>() == 0 {
let new_len = core::mem::size_of_val(a) / size_of::<B>();
} else if (size_of::<B>() != 0 && input_bytes % size_of::<B>() == 0)
|| (size_of::<B>() == 0 && input_bytes == 0)
{
let new_len =
if size_of::<B>() != 0 { input_bytes / size_of::<B>() } else { 0 };
Ok(unsafe { core::slice::from_raw_parts(a.as_ptr() as *const B, new_len) })
} else {
Err(PodCastError::OutputSliceWouldHaveSlop)
Expand All @@ -379,6 +371,7 @@ pub(crate) unsafe fn try_cast_slice<A: Copy, B: Copy>(
pub(crate) unsafe fn try_cast_slice_mut<A: Copy, B: Copy>(
a: &mut [A],
) -> Result<&mut [B], PodCastError> {
let input_bytes = core::mem::size_of_val::<[A]>(a);
// Note(Lokathor): everything with `align_of` and `size_of` will optimize away
// after monomorphization.
if align_of::<B>() > align_of::<A>()
Expand All @@ -389,10 +382,11 @@ pub(crate) unsafe fn try_cast_slice_mut<A: Copy, B: Copy>(
Ok(unsafe {
core::slice::from_raw_parts_mut(a.as_mut_ptr() as *mut B, a.len())
})
} else if size_of::<A>() == 0 || size_of::<B>() == 0 {
Err(PodCastError::SizeMismatch)
} else if core::mem::size_of_val(a) % size_of::<B>() == 0 {
let new_len = core::mem::size_of_val(a) / size_of::<B>();
} else if (size_of::<B>() != 0 && input_bytes % size_of::<B>() == 0)
|| (size_of::<B>() == 0 && input_bytes == 0)
{
let new_len =
if size_of::<B>() != 0 { input_bytes / size_of::<B>() } else { 0 };
Ok(unsafe {
core::slice::from_raw_parts_mut(a.as_mut_ptr() as *mut B, new_len)
})
Expand Down
11 changes: 5 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,14 @@ pub use bytemuck_derive::{
/// The things that can go wrong when casting between [`Pod`] data forms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PodCastError {
/// You tried to cast a slice to an element type with a higher alignment
/// requirement but the slice wasn't aligned.
/// You tried to cast a reference into a reference to a type with a higher alignment
/// requirement but the input reference wasn't aligned.
TargetAlignmentGreaterAndInputNotAligned,
/// If the element size changes then the output slice changes length
/// accordingly. If the output slice wouldn't be a whole number of elements
/// If the element size of a slice changes, then the output slice changes length
/// accordingly. If the output slice wouldn't be a whole number of elements,
/// then the conversion fails.
OutputSliceWouldHaveSlop,
/// When casting a slice you can't convert between ZST elements and non-ZST
/// elements. When casting an individual `T`, `&T`, or `&mut T` value the
/// When casting an individual `T`, `&T`, or `&mut T` value the
/// source size and destination size must be an exact match.
SizeMismatch,
/// For this type of cast the alignments must be exactly the same and they
Expand Down
34 changes: 27 additions & 7 deletions src/must.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ impl<A, B> Cast<A, B> {
const ASSERT_ALIGN_GREATER_THAN_EQUAL: () =
assert!(align_of::<A>() >= align_of::<B>());
const ASSERT_SIZE_EQUAL: () = assert!(size_of::<A>() == size_of::<B>());
const ASSERT_SIZE_MULTIPLE_OF: () = assert!(
(size_of::<A>() == 0) == (size_of::<B>() == 0)
&& (size_of::<A>() % size_of::<B>() == 0)
const ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST: () = assert!(
(size_of::<A>() == 0)
|| (size_of::<B>() != 0 && size_of::<A>() % size_of::<B>() == 0)
);
}

Expand Down Expand Up @@ -113,15 +113,20 @@ pub fn must_cast_mut<
/// * If the target type has a greater alignment requirement.
/// * If the target element type doesn't evenly fit into the the current element
/// type (eg: 3 `u16` values is 1.5 `u32` values, so that's a failure).
/// * Similarly, you can't convert between a [ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts)
/// and a non-ZST.
/// * Similarly, you can't convert from a non-[ZST](https://doc.rust-lang.org/nomicon/exotic-sizes.html#zero-sized-types-zsts)
/// to a ZST (e.g. 3 `u8` values is not any number of `()` values).
///
/// ## Examples
/// ```
/// let indicies: &[u16] = &[1, 2, 3];
/// // compiles:
/// let bytes: &[u8] = bytemuck::must_cast_slice(indicies);
/// ```
/// ```
/// let zsts: &[()] = &[(), (), ()];
/// // compiles:
/// let bytes: &[u8] = bytemuck::must_cast_slice(zsts);
/// ```
/// ```compile_fail,E0080
/// # let bytes : &[u8] = &[1, 0, 2, 0, 3, 0];
/// // fails to compile (bytes.len() might not be a multiple of 2):
Expand All @@ -132,9 +137,14 @@ pub fn must_cast_mut<
/// // fails to compile (alignment requirements increased):
/// let indicies : &[u16] = bytemuck::must_cast_slice(byte_pairs);
/// ```
/// ```compile_fail,E0080
/// let bytes: &[u8] = &[];
/// // fails to compile: (bytes.len() might not be 0)
/// let zsts: &[()] = bytemuck::must_cast_slice(bytes);
/// ```
#[inline]
pub fn must_cast_slice<A: NoUninit, B: AnyBitPattern>(a: &[A]) -> &[B] {
let _ = Cast::<A, B>::ASSERT_SIZE_MULTIPLE_OF;
let _ = Cast::<A, B>::ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST;
let _ = Cast::<A, B>::ASSERT_ALIGN_GREATER_THAN_EQUAL;
let new_len = if size_of::<A>() == size_of::<B>() {
a.len()
Expand All @@ -156,6 +166,11 @@ pub fn must_cast_slice<A: NoUninit, B: AnyBitPattern>(a: &[A]) -> &[B] {
/// // compiles:
/// let bytes: &mut [u8] = bytemuck::must_cast_slice_mut(indicies);
/// ```
/// ```
/// let zsts: &mut [()] = &mut [(), (), ()];
/// // compiles:
/// let bytes: &mut [u8] = bytemuck::must_cast_slice_mut(zsts);
/// ```
/// ```compile_fail,E0080
/// # let mut bytes = [1, 0, 2, 0, 3, 0];
/// # let bytes : &mut [u8] = &mut bytes[..];
Expand All @@ -168,14 +183,19 @@ pub fn must_cast_slice<A: NoUninit, B: AnyBitPattern>(a: &[A]) -> &[B] {
/// // fails to compile (alignment requirements increased):
/// let indicies : &mut [u16] = bytemuck::must_cast_slice_mut(byte_pairs);
/// ```
/// ```compile_fail,E0080
/// let bytes: &mut [u8] = &mut [];
/// // fails to compile: (bytes.len() might not be 0)
/// let zsts: &mut [()] = bytemuck::must_cast_slice_mut(bytes);
/// ```
#[inline]
pub fn must_cast_slice_mut<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
a: &mut [A],
) -> &mut [B] {
let _ = Cast::<A, B>::ASSERT_SIZE_MULTIPLE_OF;
let _ = Cast::<A, B>::ASSERT_SIZE_MULTIPLE_OF_OR_INPUT_ZST;
let _ = Cast::<A, B>::ASSERT_ALIGN_GREATER_THAN_EQUAL;
let new_len = if size_of::<A>() == size_of::<B>() {
a.len()
Expand Down
39 changes: 30 additions & 9 deletions tests/cast_slice_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,30 @@ fn test_panics() {
should_panic!(from_bytes::<u32>(&aligned_bytes[1..5]));
}

#[test]
fn test_zsts() {
#[derive(Debug, Clone, Copy)]
struct MyZst;
unsafe impl Zeroable for MyZst {}
unsafe impl Pod for MyZst {}
assert_eq!(42, cast_slice::<(), MyZst>(&[(); 42]).len());
assert_eq!(42, cast_slice_mut::<(), MyZst>(&mut [(); 42]).len());
assert_eq!(0, cast_slice::<(), u8>(&[(); 42]).len());
assert_eq!(0, cast_slice_mut::<(), u8>(&mut [(); 42]).len());
assert_eq!(0, cast_slice::<u8, ()>(&[]).len());
assert_eq!(0, cast_slice_mut::<u8, ()>(&mut []).len());

assert_eq!(
PodCastError::OutputSliceWouldHaveSlop,
try_cast_slice::<u8, ()>(&[42]).unwrap_err()
);

assert_eq!(
PodCastError::OutputSliceWouldHaveSlop,
try_cast_slice_mut::<u8, ()>(&mut [42]).unwrap_err()
);
}

#[cfg(feature = "extern_crate_alloc")]
#[test]
fn test_boxed_slices() {
Expand All @@ -209,7 +233,6 @@ fn test_boxed_slices() {
result.expect_err("u16 and i8 have different alignment");
assert_eq!(error, PodCastError::AlignmentMismatch);

// FIXME(#253): Should these next two casts' errors be consistent?
let result: Result<&[[i8; 3]], PodCastError> =
try_cast_slice(&*boxed_i8_slice);
let error =
Expand All @@ -220,7 +243,7 @@ fn test_boxed_slices() {
try_cast_slice_box(boxed_i8_slice);
let (error, boxed_i8_slice) =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

let empty: Box<[()]> = cast_slice_box::<u8, ()>(Box::new([]));
assert!(empty.is_empty());
Expand All @@ -229,7 +252,7 @@ fn test_boxed_slices() {
try_cast_slice_box(boxed_i8_slice);
let (error, boxed_i8_slice) =
result.expect_err("slice of ZST cannot be made from slice of 4 u8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

drop(boxed_i8_slice);

Expand All @@ -254,7 +277,6 @@ fn test_rc_slices() {
result.expect_err("u16 and i8 have different alignment");
assert_eq!(error, PodCastError::AlignmentMismatch);

// FIXME(#253): Should these next two casts' errors be consistent?
let result: Result<&[[i8; 3]], PodCastError> = try_cast_slice(&*rc_i8_slice);
let error =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
Expand All @@ -264,7 +286,7 @@ fn test_rc_slices() {
try_cast_slice_rc(rc_i8_slice);
let (error, rc_i8_slice) =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

let empty: Rc<[()]> = cast_slice_rc::<u8, ()>(Rc::new([]));
assert!(empty.is_empty());
Expand All @@ -273,7 +295,7 @@ fn test_rc_slices() {
try_cast_slice_rc(rc_i8_slice);
let (error, rc_i8_slice) =
result.expect_err("slice of ZST cannot be made from slice of 4 u8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

drop(rc_i8_slice);

Expand All @@ -299,7 +321,6 @@ fn test_arc_slices() {
result.expect_err("u16 and i8 have different alignment");
assert_eq!(error, PodCastError::AlignmentMismatch);

// FIXME(#253): Should these next two casts' errors be consistent?
let result: Result<&[[i8; 3]], PodCastError> = try_cast_slice(&*arc_i8_slice);
let error =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
Expand All @@ -309,7 +330,7 @@ fn test_arc_slices() {
try_cast_slice_arc(arc_i8_slice);
let (error, arc_i8_slice) =
result.expect_err("slice of [i8; 3] cannot be made from slice of 4 i8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

let empty: Arc<[()]> = cast_slice_arc::<u8, ()>(Arc::new([]));
assert!(empty.is_empty());
Expand All @@ -318,7 +339,7 @@ fn test_arc_slices() {
try_cast_slice_arc(arc_i8_slice);
let (error, arc_i8_slice) =
result.expect_err("slice of ZST cannot be made from slice of 4 u8s");
assert_eq!(error, PodCastError::SizeMismatch);
assert_eq!(error, PodCastError::OutputSliceWouldHaveSlop);

drop(arc_i8_slice);

Expand Down

0 comments on commit 291a924

Please sign in to comment.