Skip to content

Commit

Permalink
x64: Migrate iadd_pairwise to ISLE (#4718)
Browse files Browse the repository at this point in the history
* Add a test for iadd_pairwise with swiden input

* Implement iadd_pairwise for swiden_{low,high} input

* Add a test case for iadd_pairwise with uwiden input

* Implement iadd_pairwise with uwiden
  • Loading branch information
elliottt authored Aug 16, 2022
1 parent bc8e36a commit fbfceae
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 160 deletions.
18 changes: 18 additions & 0 deletions cranelift/codegen/src/isa/x64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -2582,6 +2582,10 @@
dst))))
dst))

(decl x64_pmaddubsw (Xmm XmmMem) Xmm)
(rule (x64_pmaddubsw src1 src2)
(xmm_rm_r $I8X16 (SseOpcode.Pmaddubsw) src1 src2))

;; Helper for creating `insertps` instructions.
(decl x64_insertps (Xmm XmmMem u8) Xmm)
(rule (x64_insertps src1 src2 lane)
Expand Down Expand Up @@ -3255,6 +3259,20 @@
(ConsumesFlags.ConsumesFlagsSideEffect
(MInst.JmpTableSeq idx tmp1 tmp2 default_target jt_targets)))))

;;;; iadd_pairwise constants ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(decl iadd_pairwise_mul_const_16 () VCodeConstant)
(extern constructor iadd_pairwise_mul_const_16 iadd_pairwise_mul_const_16)

(decl iadd_pairwise_mul_const_32 () VCodeConstant)
(extern constructor iadd_pairwise_mul_const_32 iadd_pairwise_mul_const_32)

(decl iadd_pairwise_xor_const_32 () VCodeConstant)
(extern constructor iadd_pairwise_xor_const_32 iadd_pairwise_xor_const_32)

(decl iadd_pairwise_addd_const_32 () VCodeConstant)
(extern constructor iadd_pairwise_addd_const_32 iadd_pairwise_addd_const_32)

;;;; Comparisons ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(type IcmpCondResult (enum (Condition (producer ProducesFlags) (cc CC))))
Expand Down
37 changes: 37 additions & 0 deletions cranelift/codegen/src/isa/x64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -3189,3 +3189,40 @@
;; Add this second set of converted lanes to the original to properly handle
;; values greater than max signed int.
(x64_paddd tmp1 dst)))

;; Rules for `iadd_pairwise` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower
(has_type $I16X8 (iadd_pairwise
(swiden_low val @ (value_type $I8X16))
(swiden_high val))))
(let ((mul_const Xmm (x64_xmm_load_const $I8X16 (iadd_pairwise_mul_const_16))))
(x64_pmaddubsw mul_const val)))

(rule (lower
(has_type $I32X4 (iadd_pairwise
(swiden_low val @ (value_type $I16X8))
(swiden_high val))))
(let ((mul_const Xmm (x64_xmm_load_const $I16X8 (iadd_pairwise_mul_const_32))))
(x64_pmaddwd val mul_const)))

(rule (lower
(has_type $I16X8 (iadd_pairwise
(uwiden_low val @ (value_type $I8X16))
(uwiden_high val))))
(let ((mul_const Xmm (x64_xmm_load_const $I8X16 (iadd_pairwise_mul_const_16))))
(x64_pmaddubsw val mul_const)))

(rule (lower
(has_type $I32X4 (iadd_pairwise
(uwiden_low val @ (value_type $I16X8))
(uwiden_high val))))
(let ((xor_const Xmm (x64_xmm_load_const $I16X8 (iadd_pairwise_xor_const_32)))
(dst Xmm (x64_pxor val xor_const))

(madd_const Xmm (x64_xmm_load_const $I16X8 (iadd_pairwise_mul_const_32)))
(dst Xmm (x64_pmaddwd dst madd_const))

(addd_const Xmm (x64_xmm_load_const $I16X8 (iadd_pairwise_addd_const_32))))
(x64_paddd dst addd_const)))

162 changes: 2 additions & 160 deletions cranelift/codegen/src/isa/x64/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,169 +561,11 @@ fn lower_insn_to_regs(
| Opcode::FcvtToUint
| Opcode::FcvtToSint
| Opcode::FcvtToUintSat
| Opcode::FcvtToSintSat => {
| Opcode::FcvtToSintSat
| Opcode::IaddPairwise => {
implemented_in_isle(ctx);
}

Opcode::IaddPairwise => {
if let (Some(swiden_low), Some(swiden_high)) = (
matches_input(ctx, inputs[0], Opcode::SwidenLow),
matches_input(ctx, inputs[1], Opcode::SwidenHigh),
) {
let swiden_input = &[
InsnInput {
insn: swiden_low,
input: 0,
},
InsnInput {
insn: swiden_high,
input: 0,
},
];

let input_ty = ctx.input_ty(swiden_low, 0);
let output_ty = ctx.output_ty(insn, 0);
let src0 = put_input_in_reg(ctx, swiden_input[0]);
let src1 = put_input_in_reg(ctx, swiden_input[1]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
if src0 != src1 {
unimplemented!(
"iadd_pairwise not implemented for general case with different inputs"
);
}
match (input_ty, output_ty) {
(types::I8X16, types::I16X8) => {
static MUL_CONST: [u8; 16] = [0x01; 16];
let mul_const = ctx.use_constant(VCodeConstantData::WellKnown(&MUL_CONST));
let mul_const_reg = ctx.alloc_tmp(types::I8X16).only_reg().unwrap();
ctx.emit(Inst::xmm_load_const(mul_const, mul_const_reg, types::I8X16));
ctx.emit(Inst::xmm_mov(
SseOpcode::Movdqa,
RegMem::reg(mul_const_reg.to_reg()),
dst,
));
ctx.emit(Inst::xmm_rm_r(SseOpcode::Pmaddubsw, RegMem::reg(src0), dst));
}
(types::I16X8, types::I32X4) => {
static MUL_CONST: [u8; 16] = [
0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00,
0x01, 0x00, 0x01, 0x00,
];
let mul_const = ctx.use_constant(VCodeConstantData::WellKnown(&MUL_CONST));
let mul_const_reg = ctx.alloc_tmp(types::I16X8).only_reg().unwrap();
ctx.emit(Inst::xmm_load_const(mul_const, mul_const_reg, types::I16X8));
ctx.emit(Inst::xmm_mov(SseOpcode::Movdqa, RegMem::reg(src0), dst));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Pmaddwd,
RegMem::reg(mul_const_reg.to_reg()),
dst,
));
}
_ => {
unimplemented!("Type not supported for {:?}", op);
}
}
} else if let (Some(uwiden_low), Some(uwiden_high)) = (
matches_input(ctx, inputs[0], Opcode::UwidenLow),
matches_input(ctx, inputs[1], Opcode::UwidenHigh),
) {
let uwiden_input = &[
InsnInput {
insn: uwiden_low,
input: 0,
},
InsnInput {
insn: uwiden_high,
input: 0,
},
];

let input_ty = ctx.input_ty(uwiden_low, 0);
let output_ty = ctx.output_ty(insn, 0);
let src0 = put_input_in_reg(ctx, uwiden_input[0]);
let src1 = put_input_in_reg(ctx, uwiden_input[1]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
if src0 != src1 {
unimplemented!(
"iadd_pairwise not implemented for general case with different inputs"
);
}
match (input_ty, output_ty) {
(types::I8X16, types::I16X8) => {
static MUL_CONST: [u8; 16] = [0x01; 16];
let mul_const = ctx.use_constant(VCodeConstantData::WellKnown(&MUL_CONST));
let mul_const_reg = ctx.alloc_tmp(types::I8X16).only_reg().unwrap();
ctx.emit(Inst::xmm_load_const(mul_const, mul_const_reg, types::I8X16));
ctx.emit(Inst::xmm_mov(SseOpcode::Movdqa, RegMem::reg(src0), dst));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Pmaddubsw,
RegMem::reg(mul_const_reg.to_reg()),
dst,
));
}
(types::I16X8, types::I32X4) => {
static PXOR_CONST: [u8; 16] = [
0x00, 0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80,
0x00, 0x80, 0x00, 0x80,
];
let pxor_const =
ctx.use_constant(VCodeConstantData::WellKnown(&PXOR_CONST));
let pxor_const_reg = ctx.alloc_tmp(types::I16X8).only_reg().unwrap();
ctx.emit(Inst::xmm_load_const(
pxor_const,
pxor_const_reg,
types::I16X8,
));
ctx.emit(Inst::xmm_mov(SseOpcode::Movdqa, RegMem::reg(src0), dst));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Pxor,
RegMem::reg(pxor_const_reg.to_reg()),
dst,
));

static MADD_CONST: [u8; 16] = [
0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00,
0x01, 0x00, 0x01, 0x00,
];
let madd_const =
ctx.use_constant(VCodeConstantData::WellKnown(&MADD_CONST));
let madd_const_reg = ctx.alloc_tmp(types::I8X16).only_reg().unwrap();
ctx.emit(Inst::xmm_load_const(
madd_const,
madd_const_reg,
types::I16X8,
));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Pmaddwd,
RegMem::reg(madd_const_reg.to_reg()),
dst,
));
static ADDD_CONST2: [u8; 16] = [
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00,
0x00, 0x00, 0x01, 0x00,
];
let addd_const2 =
ctx.use_constant(VCodeConstantData::WellKnown(&ADDD_CONST2));
let addd_const2_reg = ctx.alloc_tmp(types::I8X16).only_reg().unwrap();
ctx.emit(Inst::xmm_load_const(
addd_const2,
addd_const2_reg,
types::I16X8,
));
ctx.emit(Inst::xmm_rm_r(
SseOpcode::Paddd,
RegMem::reg(addd_const2_reg.to_reg()),
dst,
));
}
_ => {
unimplemented!("Type not supported for {:?}", op);
}
}
} else {
unimplemented!("Operands not supported for {:?}", op);
}
}
Opcode::UwidenHigh | Opcode::UwidenLow | Opcode::SwidenHigh | Opcode::SwidenLow => {
let input_ty = ctx.input_ty(insn, 0);
let output_ty = ctx.output_ty(insn, 0);
Expand Down
38 changes: 38 additions & 0 deletions cranelift/codegen/src/isa/x64/lower/isle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,30 @@ impl Context for IsleContext<'_, '_, MInst, Flags, IsaFlags, 6> {
self.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&UINT_MASK_HIGH))
}

#[inline]
fn iadd_pairwise_mul_const_16(&mut self) -> VCodeConstant {
self.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&IADD_PAIRWISE_MUL_CONST_16))
}

#[inline]
fn iadd_pairwise_mul_const_32(&mut self) -> VCodeConstant {
self.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&IADD_PAIRWISE_MUL_CONST_32))
}

#[inline]
fn iadd_pairwise_xor_const_32(&mut self) -> VCodeConstant {
self.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&IADD_PAIRWISE_XOR_CONST_32))
}

#[inline]
fn iadd_pairwise_addd_const_32(&mut self) -> VCodeConstant {
self.lower_ctx
.use_constant(VCodeConstantData::WellKnown(&IADD_PAIRWISE_ADDD_CONST_32))
}
}

impl IsleContext<'_, '_, MInst, Flags, IsaFlags, 6> {
Expand Down Expand Up @@ -907,3 +931,17 @@ const UINT_MASK: [u8; 16] = [
const UINT_MASK_HIGH: [u8; 16] = [
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x30, 0x43, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x30, 0x43,
];

const IADD_PAIRWISE_MUL_CONST_16: [u8; 16] = [0x01; 16];

const IADD_PAIRWISE_MUL_CONST_32: [u8; 16] = [
0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00,
];

const IADD_PAIRWISE_XOR_CONST_32: [u8; 16] = [
0x00, 0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x80,
];

const IADD_PAIRWISE_ADDD_CONST_32: [u8; 16] = [
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00,
];
77 changes: 77 additions & 0 deletions cranelift/filetests/filetests/isa/x64/simd-pairwise-add.clif
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
test compile precise-output
target x86_64

function %fn1(i8x16) -> i16x8 {
block0(v0: i8x16):
v1 = swiden_low v0
v2 = swiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}

; pushq %rbp
; movq %rsp, %rbp
; block0:
; movdqa %xmm0, %xmm5
; load_const VCodeConstant(0), %xmm0
; movdqa %xmm5, %xmm6
; pmaddubsw %xmm0, %xmm6, %xmm0
; movq %rbp, %rsp
; popq %rbp
; ret

function %fn2(i16x8) -> i32x4 {
block0(v0: i16x8):
v1 = swiden_low v0
v2 = swiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}

; pushq %rbp
; movq %rsp, %rbp
; block0:
; load_const VCodeConstant(0), %xmm3
; pmaddwd %xmm0, %xmm3, %xmm0
; movq %rbp, %rsp
; popq %rbp
; ret

function %fn3(i8x16) -> i16x8 {
block0(v0: i8x16):
v1 = uwiden_low v0
v2 = uwiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}

; pushq %rbp
; movq %rsp, %rbp
; block0:
; load_const VCodeConstant(0), %xmm3
; pmaddubsw %xmm0, %xmm3, %xmm0
; movq %rbp, %rsp
; popq %rbp
; ret

function %fn4(i16x8) -> i32x4 {
block0(v0: i16x8):
v1 = uwiden_low v0
v2 = uwiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}

; pushq %rbp
; movq %rsp, %rbp
; block0:
; load_const VCodeConstant(0), %xmm3
; pxor %xmm0, %xmm3, %xmm0
; load_const VCodeConstant(1), %xmm7
; pmaddwd %xmm0, %xmm7, %xmm0
; load_const VCodeConstant(2), %xmm11
; paddd %xmm0, %xmm11, %xmm0
; movq %rbp, %rsp
; popq %rbp
; ret

0 comments on commit fbfceae

Please sign in to comment.