Skip to content

Commit

Permalink
Auto merge of #3214 - eduardosm:move-x86-code, r=RalfJung
Browse files Browse the repository at this point in the history
Move some x86 intrinsics code to helper functions in `shims::x86`

To make them reusable for intrinsics of other x86 features.

Split off from #3192
  • Loading branch information
bors committed Dec 8, 2023
2 parents 33fb35e + 44bf5fc commit a5b9f54
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 265 deletions.
285 changes: 285 additions & 0 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use rand::Rng as _;

use rustc_apfloat::{ieee::Single, Float as _};
use rustc_middle::{mir, ty};
use rustc_span::Symbol;
use rustc_target::abi::Size;
Expand Down Expand Up @@ -331,6 +334,210 @@ fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
Ok(())
}

/// Selects which unary single-precision operation `unary_op_f32` (and the
/// `unary_op_ss` / `unary_op_ps` helpers that call it) should perform.
#[derive(Copy, Clone)]
enum FloatUnaryOp {
/// Square root: sqrt(x).
///
/// <https://www.felixcloutier.com/x86/sqrtss>
/// <https://www.felixcloutier.com/x86/sqrtps>
Sqrt,
/// Approximate reciprocal: 1/x.
///
/// `unary_op_f32` emulates the approximation by perturbing the exact
/// quotient with a random relative error on the order of 2^-12.
///
/// <https://www.felixcloutier.com/x86/rcpss>
/// <https://www.felixcloutier.com/x86/rcpps>
Rcp,
/// Approximate reciprocal square root: 1/sqrt(x).
///
/// `unary_op_f32` emulates the approximation by perturbing the computed
/// value with a random relative error on the order of 2^-12.
///
/// <https://www.felixcloutier.com/x86/rsqrtss>
/// <https://www.felixcloutier.com/x86/rsqrtps>
Rsqrt,
}

/// Evaluates the scalar operation selected by `which` on the `f32` held in
/// `op` and returns the result as a scalar.
///
/// * `Sqrt` is computed with the host's `f32::sqrt`.
/// * `Rcp` divides 1 by the input using soft floats and then perturbs the
///   quotient with a random relative error on the order of 2^-12.
/// * `Rsqrt` takes the host square root, divides 1 by it using soft floats
///   and perturbs the quotient the same way.
///
/// Any error raised while converting the input scalar is propagated.
#[allow(clippy::arithmetic_side_effects)] // floating point operations without side effects
fn unary_op_f32<'tcx>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    which: FloatUnaryOp,
    op: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
    // Every variant starts from the same scalar view of the operand.
    let input = op.to_scalar();
    // Numerator shared by the two reciprocal variants.
    let one = Single::from_u128(1).value;

    let result = match which {
        FloatUnaryOp::Sqrt => {
            // FIXME using host floats
            let host = f32::from_bits(input.to_u32()?);
            Scalar::from_u32(host.sqrt().to_bits())
        }
        FloatUnaryOp::Rcp => {
            let x = input.to_f32()?;
            let exact = (one / x).value;
            // A relative error with a magnitude on the order of 2^-12 simulates
            // the inaccuracy of RCP.
            Scalar::from_f32(apply_random_float_error(this, exact, -12))
        }
        FloatUnaryOp::Rsqrt => {
            // FIXME using host floats
            let host_root = f32::from_bits(input.to_u32()?).sqrt();
            let root = Single::from_bits(host_root.to_bits().into());
            let exact = (one / root).value;
            // A relative error with a magnitude on the order of 2^-12 simulates
            // the inaccuracy of RSQRT.
            Scalar::from_f32(apply_random_float_error(this, exact, -12))
        }
    };

    Ok(result)
}

/// Perturbs a floating-point result with a random relative error drawn from
/// the interval (-2^err_scale, 2^err_scale).
///
/// The value returned is `val * (1 + err)`, where `err` is produced from the
/// interpreter's RNG: first a random `u64` for the magnitude, then a random
/// `bool` for the sign.
///
/// # Panics
///
/// Panics if `err_scale - 64` overflows `i32`.
#[allow(clippy::arithmetic_side_effects)] // floating point arithmetic cannot panic
fn apply_random_float_error<F: rustc_apfloat::Float>(
    this: &mut crate::MiriInterpCx<'_, '_>,
    val: F,
    err_scale: i32,
) -> F {
    let rng = this.machine.rng.get_mut();

    // rand(0, 2^64) * 2^(err_scale - 64) = rand(0, 1) * 2^err_scale
    let raw = F::from_u128(rng.gen::<u64>().into()).value;
    let exponent = err_scale.checked_sub(64).unwrap();
    let magnitude = raw.scalbn(exponent);

    // Pick the sign of the error at random.
    let signed_err = if rng.gen::<bool>() { -magnitude } else { magnitude };

    // Scale the value by (1 + err).
    let factor = (F::from_u128(1).value + signed_err).value;
    (val * factor).value
}

/// Applies the `which` operation to lane 0 of `op` and passes every other
/// lane of `op` through unchanged, writing all lanes to `dest`.
///
/// # Panics
///
/// Panics if `op` and `dest` do not have the same number of lanes.
fn unary_op_ss<'tcx>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    which: FloatUnaryOp,
    op: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (op, op_len) = this.operand_to_simd(op)?;
    let (dest, dest_len) = this.place_to_simd(dest)?;

    assert_eq!(dest_len, op_len);

    // Lane 0 is the only one that is actually computed.
    let first_in = this.read_immediate(&this.project_index(&op, 0)?)?;
    let first_out = unary_op_f32(this, which, &first_in)?;
    let first_dest = this.project_index(&dest, 0)?;
    this.write_scalar(first_out, &first_dest)?;

    // The upper lanes are copied verbatim from the input.
    for lane in 1..dest_len {
        let src = this.project_index(&op, lane)?;
        let dst = this.project_index(&dest, lane)?;
        this.copy_op(&src, &dst, /*allow_transmute*/ false)?;
    }

    Ok(())
}

/// Applies the `which` operation to every lane of `op` and stores each
/// result in the corresponding lane of `dest`.
///
/// # Panics
///
/// Panics if `op` and `dest` do not have the same number of lanes.
fn unary_op_ps<'tcx>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    which: FloatUnaryOp,
    op: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (op, op_len) = this.operand_to_simd(op)?;
    let (dest, dest_len) = this.place_to_simd(dest)?;

    assert_eq!(dest_len, op_len);

    for lane in 0..dest_len {
        let input = this.read_immediate(&this.project_index(&op, lane)?)?;
        let output = this.project_index(&dest, lane)?;

        let computed = unary_op_f32(this, which, &input)?;
        this.write_scalar(computed, &output)?;
    }

    Ok(())
}

/// Rounds lane 0 of `right` to an integral value using the mode encoded in
/// `rounding`, and fills the remaining lanes of `dest` from `left`.
///
/// `rounding` is read as an `i32` immediate and decoded with
/// `rounding_from_imm`.
///
/// # Panics
///
/// Panics if `left`, `right` and `dest` do not all have the same number of
/// lanes.
fn round_first<'tcx, F: rustc_apfloat::Float>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    left: &OpTy<'tcx, Provenance>,
    right: &OpTy<'tcx, Provenance>,
    rounding: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (left, left_len) = this.operand_to_simd(left)?;
    let (right, right_len) = this.operand_to_simd(right)?;
    let (dest, dest_len) = this.place_to_simd(dest)?;

    assert_eq!(dest_len, left_len);
    assert_eq!(dest_len, right_len);

    let mode = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;

    // Only lane 0 is rounded, and its input comes from `right`.
    let first_src = this.project_index(&right, 0)?;
    let value: F = this.read_scalar(&first_src)?.to_float()?;
    let rounded = value.round_to_integral(mode).value;
    let rounded_bits = Scalar::from_uint(rounded.to_bits(), Size::from_bits(F::BITS));
    let first_dst = this.project_index(&dest, 0)?;
    this.write_scalar(rounded_bits, &first_dst)?;

    // All other lanes are taken unchanged from `left`.
    for lane in 1..dest_len {
        let src = this.project_index(&left, lane)?;
        let dst = this.project_index(&dest, lane)?;
        this.copy_op(&src, &dst, /*allow_transmute*/ false)?;
    }

    Ok(())
}

/// Rounds every lane of `op` to an integral value using the mode encoded in
/// `rounding`, writing the results to the matching lanes of `dest`.
///
/// `rounding` is read as an `i32` immediate and decoded with
/// `rounding_from_imm`.
///
/// # Panics
///
/// Panics if `op` and `dest` do not have the same number of lanes.
fn round_all<'tcx, F: rustc_apfloat::Float>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    op: &OpTy<'tcx, Provenance>,
    rounding: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (op, op_len) = this.operand_to_simd(op)?;
    let (dest, dest_len) = this.place_to_simd(dest)?;

    assert_eq!(dest_len, op_len);

    let mode = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;

    for lane in 0..dest_len {
        let src = this.project_index(&op, lane)?;
        let value: F = this.read_scalar(&src)?.to_float()?;
        let rounded = value.round_to_integral(mode).value;
        let rounded_bits = Scalar::from_uint(rounded.to_bits(), Size::from_bits(F::BITS));
        let dst = this.project_index(&dest, lane)?;
        this.write_scalar(rounded_bits, &dst)?;
    }

    Ok(())
}

/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
/// `round.{ss,sd,ps,pd}` intrinsics.
fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
// The fourth bit of `rounding` only affects the SSE status
// register, which cannot be accessed from Miri (or from Rust,
// for that matter), so we can ignore it.
match rounding & !0b1000 {
// When the third bit is 0, the rounding mode is determined by the
// first two bits.
0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
0b001 => Ok(rustc_apfloat::Round::TowardNegative),
0b010 => Ok(rustc_apfloat::Round::TowardPositive),
0b011 => Ok(rustc_apfloat::Round::TowardZero),
// When the third bit is 1, the rounding mode is determined by the
// SSE status register. Since we do not support modifying it from
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
}
}

/// Converts each element of `op` from floating point to signed integer.
///
/// When the input value is NaN or out of range, fall back to minimum value.
Expand Down Expand Up @@ -408,3 +615,81 @@ fn horizontal_bin_op<'tcx>(

Ok(())
}

/// Conditionally multiplies the packed floating-point elements in
/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
/// products (up to 4), and conditionally stores the sum in `dest` using
/// the low 4 bits of `imm`.
///
/// Bit `4 + i` of `imm` decides whether the product of lane `i` is added to
/// the sum. Bit `i` of `imm` decides whether lane `i` of `dest` receives the
/// sum (bit set) or zero (bit clear).
///
/// # Panics
///
/// Panics if `left` and `right` have different lane counts, or if either the
/// inputs or `dest` have more than 4 lanes. `imm` is a `u8` with only four
/// select bits per role, so wider vectors cannot be described by it.
fn conditional_dot_product<'tcx>(
    this: &mut crate::MiriInterpCx<'_, 'tcx>,
    left: &OpTy<'tcx, Provenance>,
    right: &OpTy<'tcx, Provenance>,
    imm: &OpTy<'tcx, Provenance>,
    dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
    let (left, left_len) = this.operand_to_simd(left)?;
    let (right, right_len) = this.operand_to_simd(right)?;
    let (dest, dest_len) = this.place_to_simd(dest)?;

    assert_eq!(left_len, right_len);
    // The multiply mask lives in bits 4..=7 of the `u8` immediate. Without this
    // bound, `1 << (i + 4)` below would shift a `u8` by 8 or more for `i >= 4`:
    // a panic with overflow checks enabled, a silently wrong mask bit otherwise.
    assert!(left_len <= 4);
    assert!(dest_len <= 4);

    let imm = this.read_scalar(imm)?.to_u8()?;

    let element_layout = left.layout.field(this, 0);

    // Calculate dot product
    // Elements are floating point numbers, but we can use `from_int`
    // because the representation of 0.0 is all zero bits.
    let mut sum = ImmTy::from_int(0u8, element_layout);
    for i in 0..left_len {
        // High nibble of `imm`: include lane `i` in the sum?
        if imm & (1 << i.checked_add(4).unwrap()) != 0 {
            let left = this.read_immediate(&this.project_index(&left, i)?)?;
            let right = this.read_immediate(&this.project_index(&right, i)?)?;

            let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
            sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
        }
    }

    // Write to destination (conditioned to imm)
    for i in 0..dest_len {
        let dest = this.project_index(&dest, i)?;

        // Low nibble of `imm`: broadcast the sum to lane `i`, or zero it.
        if imm & (1 << i) != 0 {
            this.write_immediate(*sum, &dest)?;
        } else {
            this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
        }
    }

    Ok(())
}

/// Reduces the SIMD vectors `lhs` and `rhs` lane by lane into a single value
/// of type `T`.
///
/// Starting from `init`, `f` is called once per lane, in increasing lane
/// order, with the current accumulator and the two lane values. Its result
/// becomes the next accumulator. The first error, whether from reading a lane
/// or from `f`, aborts the fold and is returned.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` differ in layout or lane count.
fn bin_op_folded<'tcx, T>(
    this: &crate::MiriInterpCx<'_, 'tcx>,
    lhs: &OpTy<'tcx, Provenance>,
    rhs: &OpTy<'tcx, Provenance>,
    init: T,
    mut f: impl FnMut(T, ImmTy<'tcx, Provenance>, ImmTy<'tcx, Provenance>) -> InterpResult<'tcx, T>,
) -> InterpResult<'tcx, T> {
    assert_eq!(lhs.layout, rhs.layout);

    let (lhs, lhs_len) = this.operand_to_simd(lhs)?;
    let (rhs, rhs_len) = this.operand_to_simd(rhs)?;

    assert_eq!(lhs_len, rhs_len);

    (0..lhs_len).try_fold(init, |acc, lane| -> InterpResult<'tcx, T> {
        // Project both lanes before reading either one.
        let lhs_lane = this.project_index(&lhs, lane)?;
        let rhs_lane = this.project_index(&rhs, lane)?;

        let lhs_val = this.read_immediate(&lhs_lane)?;
        let rhs_val = this.read_immediate(&rhs_lane)?;
        f(acc, lhs_val, rhs_val)
    })
}
Loading

0 comments on commit a5b9f54

Please sign in to comment.