Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mono][Arm64] Intrinsify methods for Vector4 on Arm64 #72124

Merged
merged 12 commits into from
Jul 27, 2022
31 changes: 26 additions & 5 deletions src/mono/mono/mini/mini-llvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,10 @@ ovr_tag_from_mono_vector_class (MonoClass *klass) {
case 8: ret |= INTRIN_vector64; break;
case 16: ret |= INTRIN_vector128; break;
}

if (!strcmp ("Vector4", m_class_get_name (klass)) || !strcmp ("Vector2", m_class_get_name (klass)))
return ret | INTRIN_float32;

MonoType *etype = mono_class_get_context (klass)->class_inst->type_argv [0];
switch (etype->type) {
case MONO_TYPE_I1: case MONO_TYPE_U1: ret |= INTRIN_int8; break;
Expand Down Expand Up @@ -1419,9 +1423,9 @@ convert_full (EmitContext *ctx, LLVMValueRef v, LLVMTypeRef dtype, gboolean is_u

if (LLVMGetTypeKind (stype) == LLVMPointerTypeKind && LLVMGetTypeKind (dtype) == LLVMPointerTypeKind)
return LLVMBuildBitCast (ctx->builder, v, dtype, "");
if (LLVMGetTypeKind (dtype) == LLVMPointerTypeKind)
if (LLVMGetTypeKind (dtype) == LLVMPointerTypeKind && LLVMGetTypeKind (stype) == LLVMIntegerTypeKind)
return LLVMBuildIntToPtr (ctx->builder, v, dtype, "");
if (LLVMGetTypeKind (stype) == LLVMPointerTypeKind)
if (LLVMGetTypeKind (stype) == LLVMPointerTypeKind && LLVMGetTypeKind (dtype) == LLVMIntegerTypeKind)
return LLVMBuildPtrToInt (ctx->builder, v, dtype, "");

if (mono_arch_is_soft_float ()) {
Expand Down Expand Up @@ -4092,6 +4096,7 @@ emit_entry_bb (EmitContext *ctx, LLVMBuilderRef builder)
// FIXME: Enabling this fails on windows
case LLVMArgVtypeAddr:
case LLVMArgVtypeByRef:
case LLVMArgAsFpArgs:
{
if (MONO_CLASS_IS_SIMD (ctx->cfg, mono_class_from_mono_type_internal (ainfo->type)))
/* Treat these as normal values */
Expand Down Expand Up @@ -4793,6 +4798,9 @@ process_call (EmitContext *ctx, MonoBasicBlock *bb, LLVMBuilderRef *builder_ref,
if (!addresses [call->inst.dreg])
addresses [call->inst.dreg] = build_alloca_address (ctx, sig->ret);
LLVMBuildStore (builder, lcall, convert_full (ctx, addresses [call->inst.dreg]->value, pointer_type (LLVMTypeOf (lcall)), FALSE));

load_name = "process_call_fp_struct";
should_promote_to_value = is_simd;
break;
case LLVMArgVtypeByVal:
/*
Expand Down Expand Up @@ -5993,10 +6001,23 @@ process_bb (EmitContext *ctx, MonoBasicBlock *bb)
case LLVMArgAsIArgs:
case LLVMArgFpStruct: {
LLVMTypeRef ret_type = LLVMGetReturnType (LLVMGetElementType (LLVMTypeOf (method)));
LLVMValueRef retval;
LLVMValueRef retval, elem;
gboolean is_simd = MONO_CLASS_IS_SIMD (ctx->cfg, mono_class_from_mono_type_internal (sig->ret));

g_assert (addresses [ins->sreg1]);
retval = LLVMBuildLoad2 (builder, ret_type, convert (ctx, addresses [ins->sreg1]->value, pointer_type (ret_type)), "");
if (is_simd) {
g_assert (lhs);
retval = LLVMConstNull(ret_type);

int len = LLVMGetVectorSize (LLVMTypeOf (lhs));
for (int i = 0; i < len; i++)
{
elem = LLVMBuildExtractElement (builder, lhs, const_int32 (i), "extract_elem");
retval = LLVMBuildInsertValue (builder, retval, elem, i, "insert_val_struct");
}
} else{
g_assert (addresses [ins->sreg1]);
retval = LLVMBuildLoad2 (builder, ret_type, convert (ctx, addresses [ins->sreg1]->value, pointer_type (ret_type)), "");
}
LLVMBuildRet (builder, retval);
break;
}
Expand Down
2 changes: 1 addition & 1 deletion src/mono/mono/mini/mini-runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -4351,7 +4351,7 @@ init_class (MonoClass *klass)

const char *name = m_class_get_name (klass);

#ifdef TARGET_AMD64
fanyang-mono marked this conversation as resolved.
Show resolved Hide resolved
#if defined(TARGET_AMD64) || defined(TARGET_ARM64)
/*
* Some of the intrinsics used by the VectorX classes are only implemented on amd64 and arm64.
* The JIT can't handle SIMD types with != 16 size yet.
Expand Down
103 changes: 55 additions & 48 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -306,18 +306,20 @@ emit_simd_ins_for_binary_op (MonoCompile *cfg, MonoClass *klass, MonoMethodSigna
break;
case SN_Multiply:
case SN_op_Multiply:
if (fsig->params [1]->type != MONO_TYPE_GENERICINST) {
MonoInst* ins = emit_simd_ins (cfg, klass, OP_CREATE_SCALAR_UNSAFE, args [1]->dreg, -1);
ins->inst_c1 = arg_type;
ins = emit_simd_ins (cfg, klass, OP_XBINOP_BYSCALAR, args [0]->dreg, ins->dreg);
ins->inst_c0 = OP_FMUL;
return ins;
} else if (fsig->params [0]->type != MONO_TYPE_GENERICINST) {
MonoInst* ins = emit_simd_ins (cfg, klass, OP_CREATE_SCALAR_UNSAFE, args [0]->dreg, -1);
ins->inst_c1 = arg_type;
ins = emit_simd_ins (cfg, klass, OP_XBINOP_BYSCALAR, ins->dreg, args [1]->dreg);
ins->inst_c0 = OP_FMUL;
return ins;
if (strcmp ("Vector4", m_class_get_name (klass)) && strcmp ("Vector2", m_class_get_name (klass))) {
if (fsig->params [1]->type != MONO_TYPE_GENERICINST) {
MonoInst* ins = emit_simd_ins (cfg, klass, OP_CREATE_SCALAR_UNSAFE, args [1]->dreg, -1);
ins->inst_c1 = arg_type;
ins = emit_simd_ins (cfg, klass, OP_XBINOP_BYSCALAR, args [0]->dreg, ins->dreg);
ins->inst_c0 = OP_FMUL;
return ins;
} else if (fsig->params [0]->type != MONO_TYPE_GENERICINST) {
MonoInst* ins = emit_simd_ins (cfg, klass, OP_CREATE_SCALAR_UNSAFE, args [0]->dreg, -1);
ins->inst_c1 = arg_type;
ins = emit_simd_ins (cfg, klass, OP_XBINOP_BYSCALAR, ins->dreg, args [1]->dreg);
ins->inst_c0 = OP_FMUL;
return ins;
}
}
instc0 = OP_FMUL;
break;
Expand Down Expand Up @@ -512,8 +514,15 @@ emit_sum_vector (MonoCompile *cfg, MonoType *vector_type, MonoTypeEnum element_t
{
MonoClass *vector_class = mono_class_from_mono_type_internal (vector_type);
int vector_size = mono_class_value_size (vector_class, NULL);
MonoClass *element_class = mono_class_from_mono_type_internal (get_vector_t_elem_type (vector_type));
int element_size = mono_class_value_size (element_class, NULL);
int element_size;
if (!strcmp ("Vector4", m_class_get_name (vector_class)))
element_size = vector_size / 4;
else if (!strcmp ("Vector2", m_class_get_name (vector_class)))
element_size = vector_size / 2;
else {
MonoClass *element_class = mono_class_from_mono_type_internal (get_vector_t_elem_type (vector_type));
element_size = mono_class_value_size (element_class, NULL);
}
gboolean has_single_element = vector_size == element_size;

// If there's just one element we need to extract it instead of summing the whole array
Expand Down Expand Up @@ -783,7 +792,7 @@ emit_vector_create_elementwise (
return ins;
}

#if defined(TARGET_AMD64) || defined(TARGET_ARM64) || defined(TARGET_WASM)
#if defined(TARGET_AMD64) || defined(TARGET_ARM64) || defined(TARGET_WASM)

static int
type_to_xinsert_op (MonoTypeEnum type)
Expand Down Expand Up @@ -1547,20 +1556,20 @@ emit_vector64_vector128_t (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSign
return NULL;
}

#endif // defined(TARGET_AMD64) || defined(TARGET_ARM64)

#ifdef TARGET_AMD64

// System.Numerics.Vector2/Vector3/Vector4
static guint16 vector2_methods[] = {
SN_ctor,
SN_Abs,
SN_Add,
SN_CopyTo,
SN_Divide,
SN_Dot,
SN_GetElement,
SN_Max,
SN_Min,
SN_Multiply,
SN_SquareRoot,
SN_Subtract,
SN_WithElement,
SN_get_Item,
SN_get_One,
Expand Down Expand Up @@ -1713,6 +1722,10 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
ins->inst_c1 = MONO_TYPE_R4;
return ins;
}
case SN_Add:
case SN_Divide:
case SN_Multiply:
case SN_Subtract:
case SN_op_Addition:
case SN_op_Division:
case SN_op_Multiply:
Expand All @@ -1721,34 +1734,13 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
case SN_Min:
if (!(!fsig->hasthis && fsig->param_count == 2 && mono_metadata_type_equal (fsig->ret, type) && mono_metadata_type_equal (fsig->params [0], type) && mono_metadata_type_equal (fsig->params [1], type)))
return NULL;
ins = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, args [1]->dreg);
ins->inst_c1 = etype->type;

switch (id) {
case SN_op_Addition:
ins->inst_c0 = OP_FADD;
break;
case SN_op_Division:
ins->inst_c0 = OP_FDIV;
break;
case SN_op_Multiply:
ins->inst_c0 = OP_FMUL;
break;
case SN_op_Subtraction:
ins->inst_c0 = OP_FSUB;
break;
case SN_Max:
ins->inst_c0 = OP_FMAX;
break;
case SN_Min:
ins->inst_c0 = OP_FMIN;
break;
default:
g_assert_not_reached ();
break;
}
return ins;
return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, MONO_TYPE_R4, id);
case SN_Dot: {
#ifdef TARGET_ARM64
int instc0 = OP_FMUL;
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc0, MONO_TYPE_R4, fsig, args);
return emit_sum_vector (cfg, fsig->params [0], MONO_TYPE_R4, pairwise_multiply);
#elif defined(TARGET_AMD64)
if (!(mini_get_cpu_features (cfg) & MONO_CPU_X86_SSE41))
return NULL;

Expand All @@ -1764,6 +1756,9 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
ins->inst_c1 = MONO_TYPE_R4;
MONO_ADD_INS (cfg->cbb, ins);
return ins;
#else
return NULL;
#endif
}
case SN_Abs: {
// MAX(x,0-x)
Expand Down Expand Up @@ -1791,9 +1786,15 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
return ins;
}
case SN_SquareRoot: {
#ifdef TARGET_ARM64
return emit_simd_ins_for_sig (cfg, klass, OP_XOP_OVR_X_X, INTRINS_AARCH64_ADV_SIMD_FSQRT, MONO_TYPE_R4, fsig, args);
#elif defined(TARGET_AMD64)
ins = emit_simd_ins (cfg, klass, OP_XOP_X_X, args [0]->dreg, -1);
ins->inst_c0 = (IntrinsicId)INTRINS_SSE_SQRT_PS;
return ins;
#else
return NULL;
#endif
}
case SN_CopyTo:
// FIXME:
Expand All @@ -1805,9 +1806,9 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
return NULL;
}

#endif /* TARGET_AMD64 */
#endif // defined(TARGET_AMD64) || defined(TARGET_ARM64) || defined(TARGET_WASM)

#if defined(TARGET_AMD64)
#ifdef TARGET_AMD64

static guint16 vector_methods [] = {
SN_ConvertToDouble,
Expand Down Expand Up @@ -4027,6 +4028,12 @@ arch_emit_simd_intrinsics (const char *class_ns, const char *class_name, MonoCom
return emit_vector64_vector128_t (cfg, cmethod, fsig, args);
}

if (!strcmp (class_ns, "System.Numerics")) {
// TODO: Only Vector4 is dispatched to emit_vector_2_3_4 here for now; Vector2 and Vector3 are not enabled yet.
if (!strcmp ("Vector4", class_name))
return emit_vector_2_3_4 (cfg, cmethod, fsig, args);
}

return NULL;
}
#elif TARGET_AMD64
Expand Down