Have mono handle the vector as APIs that grow or shrink the vector type #104445

Closed · wants to merge 12 commits · Changes from 11 commits
@@ -372,16 +372,6 @@ public static Vector128<T> operator >>>(Vector128<T> value, int shiftCount)
/// <exception cref="NotSupportedException">The type of the vector (<typeparamref name="T" />) is not supported.</exception>
public override bool Equals([NotNullWhen(true)] object? obj) => (obj is Vector128<T> other) && Equals(other);

// Account for floating-point equality around NaN
// This is in a separate method so it can be optimized by the mono interpreter/jiterpreter
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool EqualsFloatingPoint(Vector128<T> lhs, Vector128<T> rhs)
{
Vector128<T> result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs));
return result.AsInt32() == Vector128<int>.AllBitsSet;
}

/// <summary>Determines whether the specified <see cref="Vector128{T}" /> is equal to the current instance.</summary>
/// <param name="other">The <see cref="Vector128{T}" /> to compare with the current instance.</param>
/// <returns><c>true</c> if <paramref name="other" /> is equal to the current instance; otherwise, <c>false</c>.</returns>
@@ -395,7 +385,8 @@ public bool Equals(Vector128<T> other)
{
if ((typeof(T) == typeof(double)) || (typeof(T) == typeof(float)))
{
return EqualsFloatingPoint(this, other);
Vector128<T> result = Vector128.Equals(this, other) | ~(Vector128.Equals(this, this) | Vector128.Equals(other, other));
return result.AsInt32() == Vector128<int>.AllBitsSet;
}
else
{
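
Both the removed EqualsFloatingPoint helper and the inlined replacement rest on the same lane-wise identity: Equals(a, b) | ~(Equals(a, a) | Equals(b, b)) is all-ones when two lanes hold equal values or when both hold NaN. A minimal scalar sketch of that identity, written in plain C purely for illustration (the helper name below is made up and does not appear in the PR):

    #include <math.h>
    #include <stdbool.h>

    /* Per-lane view of (a == b) | ~((a == a) | (b == b)):
       the second term is true only when a != a AND b != b, i.e. when both lanes
       are NaN, so Equals() treats NaN as equal to NaN, unlike an ordinary compare. */
    static bool lane_equals (float a, float b)
    {
        return (a == b) || (isnan (a) && isnan (b));
    }
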
27 changes: 0 additions & 27 deletions src/mono/browser/runtime/jiterpreter-trace-generator.ts
@@ -3748,33 +3748,6 @@ function emit_simd_3 (builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrin
builder.appendU8(WasmOpcode.i32_eqz);
append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store);
return true;
case SimdIntrinsic3.V128_R4_FLOAT_EQUALITY:
case SimdIntrinsic3.V128_R8_FLOAT_EQUALITY: {
/*
Vector128<T> result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs));
return result.AsInt32() == Vector128<int>.AllBitsSet;
*/
const isR8 = index === SimdIntrinsic3.V128_R8_FLOAT_EQUALITY,
eqOpcode = isR8 ? WasmSimdOpcode.f64x2_eq : WasmSimdOpcode.f32x4_eq;
builder.local("pLocals");
append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load);
builder.local("math_lhs128", WasmOpcode.tee_local);
append_ldloc(builder, getArgU16(ip, 3), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load);
builder.local("math_rhs128", WasmOpcode.tee_local);
builder.appendSimd(eqOpcode);
builder.local("math_lhs128");
builder.local("math_lhs128");
builder.appendSimd(eqOpcode);
builder.local("math_rhs128");
builder.local("math_rhs128");
builder.appendSimd(eqOpcode);
builder.appendSimd(WasmSimdOpcode.v128_or);
builder.appendSimd(WasmSimdOpcode.v128_not);
builder.appendSimd(WasmSimdOpcode.v128_or);
builder.appendSimd(isR8 ? WasmSimdOpcode.i64x2_all_true : WasmSimdOpcode.i32x4_all_true);
append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store);
return true;
}
case SimdIntrinsic3.V128_I1_SHUFFLE: {
// Detect a constant indices vector and turn it into a const. This allows
// v8 to use a more optimized implementation of the swizzle opcode
2 changes: 2 additions & 0 deletions src/mono/mono/mini/interp/interp-internals.h
@@ -31,6 +31,8 @@
#define MINT_STACK_ALIGNMENT (2 * MINT_STACK_SLOT_SIZE)
#define MINT_SIMD_ALIGNMENT (MINT_STACK_ALIGNMENT)
#define SIZEOF_V128 16
#define SIZEOF_V2 8
#define SIZEOF_V3 12

#define INTERP_STACK_SIZE (1024*1024)
#define INTERP_REDZONE_SIZE (8*1024)
13 changes: 11 additions & 2 deletions src/mono/mono/mini/interp/interp-simd-intrins.def
@@ -58,8 +58,11 @@ INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_OR, interp_v128_o
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY, interp_v128_op_bitwise_equality, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_INEQUALITY, interp_v128_op_bitwise_inequality, -1)

INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY, interp_v128_r4_float_equality, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY, interp_v128_r8_float_equality, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_INSTANCE_EQUALS_R4, interp_v128_instance_equals_r4, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V2_INSTANCE_EQUALS_R4, interp_v2_instance_equals_r4, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V3_INSTANCE_EQUALS_R4, interp_v3_instance_equals_r4, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_INSTANCE_EQUALS_R8, interp_v128_instance_equals_r8, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_INSTANCE_EQUALS_BITWISE, interp_v128_instance_equals_bitwise, -1)

INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR, interp_v128_op_exclusive_or, 81)

@@ -71,6 +74,12 @@ INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R4_MULTIPLY, interp_v128_
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R4_DIVISION, interp_v128_r4_op_division, 231)

INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V128_BITCAST, interp_v128_bitcast, -1)
INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V128_TO_V2, interp_v128_to_v2, -1)
INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V128_TO_V3, interp_v128_to_v3, -1)
INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V2_TO_V128, interp_v2_to_v128, -1)
INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V2_TO_V3, interp_v2_to_v3, -1)
INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V3_TO_V128, interp_v3_to_v128, -1)
INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V3_TO_V2, interp_v3_to_v2, -1)

INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V128_I1_NEGATION, interp_v128_i1_op_negation, 97)
INTERP_SIMD_INTRINSIC_P_P (INTERP_SIMD_INTRINSIC_V128_I2_NEGATION, interp_v128_i2_op_negation, 129)
105 changes: 100 additions & 5 deletions src/mono/mono/mini/interp/interp-simd.c
@@ -6,6 +6,8 @@
#include <wasm_simd128.h>
#endif

#include <mono/utils/mono-math.h>

#ifdef INTERP_ENABLE_SIMD

gboolean interp_simd_enabled = TRUE;
@@ -35,6 +37,65 @@ interp_v128_bitcast (gpointer res, gpointer v1)
*(v128_i1*)res = *(v128_i1*)v1;
}

// Vector2 AsVector2(Vector128<float> v1)
static void
interp_v128_to_v2 (gpointer res, gpointer v1)
{
memcpy (res, v1, SIZEOF_V2);
}

// Vector3 AsVector3(Vector128<float> v1)
static void
interp_v128_to_v3 (gpointer res, gpointer v1)
{
memcpy (res, v1, SIZEOF_V3);
}

// Vector128<float> AsVector128(Vector2 v1)
static void
interp_v2_to_v128 (gpointer res, gpointer v1)
{
float *res_typed = (float*)res;
float *v1_typed = (float*)v1;

res_typed [0] = v1_typed [0];
res_typed [1] = v1_typed [1];
res_typed [2] = 0;
res_typed [3] = 0;
}

// Vector3 AsVector3(Vector2 v1)
static void
interp_v2_to_v3 (gpointer res, gpointer v1)
{
float *res_typed = (float*)res;
float *v1_typed = (float*)v1;

res_typed [0] = v1_typed [0];
res_typed [1] = v1_typed [1];
res_typed [2] = 0;
Contributor:

There might be a problem here - if the v3 is in a stack local, it's 16 bytes wide and you might need to zero [3]. I'm not sure whether res can be a non-stack address though...

Member Author:

I had code earlier that tried to always handle it as 16-bytes, but it didn't help.

Notably, I wouldn't expect the need to explicitly zero that space anyways as it would be considered padding space and should be ignored when loaded. Otherwise it would risk corrupting other operations, like Sum.

}
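
To make the review discussion above concrete: a sketch (mine, not from the PR) of a Vector3 sitting in a 16-byte, SIMD-aligned interpreter stack slot. Bytes 12-15 are padding that element-wise Vector3 operations are expected to ignore, which is the author's argument for not zeroing the fourth lane here; the slot size and names below are assumptions for illustration only.

    /* Assumed 16-byte slot: lanes[0..2] hold x, y, z; lanes[3] overlaps padding. */
    typedef union {
        unsigned char bytes [16];
        float         lanes [4];
    } v3_slot;

    /* A Sum over a Vector3 only reads lanes 0..2, so whatever happens to sit in
       the padding bytes never influences the result. */
    static float v3_sum (const v3_slot *v)
    {
        return v->lanes [0] + v->lanes [1] + v->lanes [2];
    }
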

// Vector128<float> AsVector128(Vector3 v1)
static void
interp_v3_to_v128 (gpointer res, gpointer v1)
{
float *res_typed = (float*)res;
float *v1_typed = (float*)v1;

res_typed [0] = v1_typed [0];
res_typed [1] = v1_typed [1];
res_typed [2] = v1_typed [2];
res_typed [3] = 0;
}

// Vector2 AsVector2(Vector3 v1)
static void
interp_v3_to_v2 (gpointer res, gpointer v1)
{
memcpy (res, v1, SIZEOF_V2);
}

// op_Addition
static void
interp_v128_i1_op_addition (gpointer res, gpointer v1, gpointer v2)
@@ -132,29 +193,63 @@ interp_v128_op_bitwise_inequality (gpointer res, gpointer v1, gpointer v2)
*(gint32*)res = 1;
}

// Vector128<float>EqualsFloatingPoint
// Vector128<float>.Equals
static void
interp_v128_r4_float_equality (gpointer res, gpointer v1, gpointer v2)
interp_v128_instance_equals_r4 (gpointer res, gpointer v1, gpointer v2)
{
v128_r4 v1_cast = *(v128_r4*)v1;
v128_r4 v1_cast = **(v128_r4**)v1;
v128_r4 v2_cast = *(v128_r4*)v2;
v128_r4 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast));
memset (&v1_cast, 0xff, SIZEOF_V128);

*(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V128) == 0;
}
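
A note on the double dereference above (my reading of the change, not stated in the diff): instance Equals is a struct method, so its first interpreter argument is the `this` reference, i.e. a pointer to the vector, while the second operand arrives by value. A tiny sketch of the two access patterns, with an assumed typedef mirroring the runtime's vector type:

    /* Assumed to mirror mono's v128_r4: four packed floats (GCC/clang vector extension). */
    typedef float v128_r4 __attribute__ ((vector_size (16)));

    /* v1 holds a pointer to the vector (`this` passed by reference);
       v2 holds the vector value itself (the by-value operand). */
    static void load_equals_operands (void *v1, void *v2, v128_r4 *self, v128_r4 *other)
    {
        *self  = **(v128_r4**)v1;
        *other = *(v128_r4*)v2;
    }
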

// Vector2.Equals
static void
interp_v2_instance_equals_r4 (gpointer res, gpointer v1, gpointer v2)
{
v128_r4 v1_cast;
interp_v2_to_v128 (&v1_cast, v1);
v128_r4 v2_cast = *(v128_r4*)v2;
v128_r4 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast));
memset (&v1_cast, 0xff, SIZEOF_V2);

*(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V2) == 0;
}

// Vector3.Equals
static void
interp_v128_r8_float_equality (gpointer res, gpointer v1, gpointer v2)
interp_v3_instance_equals_r4 (gpointer res, gpointer v1, gpointer v2)
{
v128_r8 v1_cast = *(v128_r8*)v1;
v128_r4 v1_cast;
interp_v3_to_v128 (&v1_cast, v1);
v128_r4 v2_cast = *(v128_r4*)v2;
v128_r4 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast));
memset (&v1_cast, 0xff, SIZEOF_V3);

*(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V3) == 0;
}

// Vector128<double>.Equals
static void
interp_v128_instance_equals_r8 (gpointer res, gpointer v1, gpointer v2)
{
v128_r8 v1_cast = **(v128_r8**)v1;
v128_r8 v2_cast = *(v128_r8*)v2;
v128_r8 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast));
memset (&v1_cast, 0xff, SIZEOF_V128);

*(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V128) == 0;
}

// Vector128<T>.Equals, for integer T
static void
interp_v128_instance_equals_bitwise (gpointer res, gpointer v1, gpointer v2)
{
interp_v128_op_bitwise_equality(res, *(v128_i1**)v1, v2);
}

// op_Multiply
static void
interp_v128_i1_op_multiply (gpointer res, gpointer v1, gpointer v2)
Expand Down
6 changes: 4 additions & 2 deletions src/mono/mono/mini/interp/simd-methods.def
@@ -39,16 +39,18 @@ SIMD_METHOD(AsUInt16)
SIMD_METHOD(AsUInt32)
SIMD_METHOD(AsUInt64)
SIMD_METHOD(AsVector)
SIMD_METHOD(AsVector4)
SIMD_METHOD(AsVector128)
SIMD_METHOD(AsVector128Unsafe)
SIMD_METHOD(AsVector2)
SIMD_METHOD(AsVector3)
SIMD_METHOD(AsVector4)
SIMD_METHOD(ConditionalSelect)
SIMD_METHOD(Create)
SIMD_METHOD(CreateScalar)
SIMD_METHOD(CreateScalarUnsafe)

SIMD_METHOD(Equals)
SIMD_METHOD(EqualsAny)
SIMD_METHOD(EqualsFloatingPoint)
SIMD_METHOD(ExtractMostSignificantBits)
SIMD_METHOD(GreaterThan)
SIMD_METHOD(LessThan)