Skip to content

Commit

Permalink
x64: Add more fma instruction lowerings (bytecodealliance#5846)
Browse files Browse the repository at this point in the history
The relaxed-simd proposal for WebAssembly adds a fused-multiply-add
operation for `v128` types so I was poking around at Cranelift's
existing support for its `fma` instruction. I was also poking around at
the x86_64 ISA's offerings for the FMA operation and ended up with this
PR that improves the lowering of the `fma` instruction on the x64
backend in a number of ways:

* A libcall-based fallback is now provided for `f32x4` and `f64x2` types
  in preparation for eventual support of the relaxed-simd proposal.
  These encodings are horribly slow, but it's expected that if FMA
  semantics must be guaranteed then it's the best that can be done
  without the `fma` feature. Otherwise it'll be up to producers (e.g.
  Wasmtime embedders) whether wasm-level FMA operations should be FMA or
  multiply-then-add.

* In addition to the existing `vfmadd213*` instructions, opcodes were
  added for `vfmadd132*`. The `132` variant is selected based on which
  argument can have a sinkable load.

* Any argument in the `fma` CLIF instruction can now have a
  `sinkable_load` and it'll generate a single FMA instruction.

* All `vfnmadd*` opcodes were added as well. These are pattern-matched
  where one of the arguments to the CLIF instruction is an `fneg`. I
  opted not to add a new CLIF instruction here since pattern matching
  seemed easy enough, but I'm not intimately familiar with the
  semantics here, so if a dedicated CLIF instruction is the preferred
  approach I can do that too.
  • Loading branch information
alexcrichton authored Feb 21, 2023
1 parent d82ebcc commit bd3dcd3
Show file tree
Hide file tree
Showing 9 changed files with 719 additions and 78 deletions.
71 changes: 41 additions & 30 deletions cranelift/codegen/src/isa/x64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,18 @@
Vfmadd213sd
Vfmadd213ps
Vfmadd213pd
Vfmadd132ss
Vfmadd132sd
Vfmadd132ps
Vfmadd132pd
Vfnmadd213ss
Vfnmadd213sd
Vfnmadd213ps
Vfnmadd213pd
Vfnmadd132ss
Vfnmadd132sd
Vfnmadd132ps
Vfnmadd132pd
Vcmpps
Vcmppd
Vpsrlw
Expand Down Expand Up @@ -1623,8 +1635,8 @@
(decl use_popcnt (bool) Type)
(extern extractor infallible use_popcnt use_popcnt)

(decl use_fma (bool) Type)
(extern extractor infallible use_fma use_fma)
(decl pure use_fma () bool)
(extern constructor use_fma use_fma)

(decl use_sse41 (bool) Type)
(extern extractor infallible use_sse41 use_sse41)
Expand Down Expand Up @@ -3598,34 +3610,33 @@
(_ Unit (emit (MInst.XmmRmRVex3 op src1 src2 src3 dst))))
dst))

;; Helper for creating `vfmadd213ss` instructions.
; TODO: This should have the (Xmm Xmm XmmMem) signature
; but we don't support VEX memory encodings yet
(decl x64_vfmadd213ss (Xmm Xmm Xmm) Xmm)
(rule (x64_vfmadd213ss x y z)
(xmm_rmr_vex3 (AvxOpcode.Vfmadd213ss) x y z))

;; Helper for creating `vfmadd213sd` instructions.
; TODO: This should have the (Xmm Xmm XmmMem) signature
; but we don't support VEX memory encodings yet
(decl x64_vfmadd213sd (Xmm Xmm Xmm) Xmm)
(rule (x64_vfmadd213sd x y z)
(xmm_rmr_vex3 (AvxOpcode.Vfmadd213sd) x y z))

;; Helper for creating `vfmadd213ps` instructions.
; TODO: This should have the (Xmm Xmm XmmMem) signature
; but we don't support VEX memory encodings yet
(decl x64_vfmadd213ps (Xmm Xmm Xmm) Xmm)
(rule (x64_vfmadd213ps x y z)
(xmm_rmr_vex3 (AvxOpcode.Vfmadd213ps) x y z))

;; Helper for creating `vfmadd213pd` instructions.
; TODO: This should have the (Xmm Xmm XmmMem) signature
; but we don't support VEX memory encodings yet
(decl x64_vfmadd213pd (Xmm Xmm Xmm) Xmm)
(rule (x64_vfmadd213pd x y z)
(xmm_rmr_vex3 (AvxOpcode.Vfmadd213pd) x y z))

;; Helper for creating `vfmadd213*` instructions.
;;
;; The lowering rules use this to compute `a * b + c`. Only `c` (the addend)
;; is typed `XmmMem`, so it is the only operand that can be a memory operand;
;; `a` and `b` must be registers.
;;
;; The `Type` argument selects the opcode: `$F32` -> `ss`, `$F64` -> `sd`,
;; `$F32X4` -> `ps`, `$F64X2` -> `pd`. No rule exists for any other type.
(decl x64_vfmadd213 (Type Xmm Xmm XmmMem) Xmm)
(rule (x64_vfmadd213 $F32 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd213ss) a b c))
(rule (x64_vfmadd213 $F64 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd213sd) a b c))
(rule (x64_vfmadd213 $F32X4 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd213ps) a b c))
(rule (x64_vfmadd213 $F64X2 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd213pd) a b c))

;; Helper for creating `vfmadd132*` instructions.
;;
;; The lowering rules use this to compute `a * c + b`. Only `c` is typed
;; `XmmMem`, so unlike the `213` form this variant lets one of the
;; *multiplied* values be a memory operand; `a` and the addend `b` must be
;; registers.
;;
;; The `Type` argument selects the opcode: `$F32` -> `ss`, `$F64` -> `sd`,
;; `$F32X4` -> `ps`, `$F64X2` -> `pd`. No rule exists for any other type.
(decl x64_vfmadd132 (Type Xmm Xmm XmmMem) Xmm)
(rule (x64_vfmadd132 $F32 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd132ss) a b c))
(rule (x64_vfmadd132 $F64 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd132sd) a b c))
(rule (x64_vfmadd132 $F32X4 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd132ps) a b c))
(rule (x64_vfmadd132 $F64X2 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd132pd) a b c))

;; Helper for creating `vfnmadd213*` instructions.
;;
;; Negated-multiply counterpart of the `vfmadd213*` helper: the lowering
;; rules use this to compute `-(a * b) + c`. Only `c` (the addend) is typed
;; `XmmMem`, so it is the only operand that can be a memory operand.
;;
;; The `Type` argument selects the opcode: `$F32` -> `ss`, `$F64` -> `sd`,
;; `$F32X4` -> `ps`, `$F64X2` -> `pd`. No rule exists for any other type.
(decl x64_vfnmadd213 (Type Xmm Xmm XmmMem) Xmm)
(rule (x64_vfnmadd213 $F32 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd213ss) a b c))
(rule (x64_vfnmadd213 $F64 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd213sd) a b c))
(rule (x64_vfnmadd213 $F32X4 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd213ps) a b c))
(rule (x64_vfnmadd213 $F64X2 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd213pd) a b c))

;; Helper for creating `vfnmadd132*` instructions.
;;
;; Negated-multiply counterpart of the `vfmadd132*` helper: the lowering
;; rules use this to compute `-(a * c) + b`. Only `c` is typed `XmmMem`, so
;; this variant lets one of the *multiplied* values be a memory operand; `a`
;; and the addend `b` must be registers.
;;
;; The `Type` argument selects the opcode: `$F32` -> `ss`, `$F64` -> `sd`,
;; `$F32X4` -> `ps`, `$F64X2` -> `pd`. No rule exists for any other type.
(decl x64_vfnmadd132 (Type Xmm Xmm XmmMem) Xmm)
(rule (x64_vfnmadd132 $F32 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd132ss) a b c))
(rule (x64_vfnmadd132 $F64 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd132sd) a b c))
(rule (x64_vfnmadd132 $F32X4 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd132ps) a b c))
(rule (x64_vfnmadd132 $F64X2 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd132pd) a b c))

;; Helper for creating `sqrtss` instructions.
(decl x64_sqrtss (XmmMem) Xmm)
Expand Down
14 changes: 13 additions & 1 deletion cranelift/codegen/src/isa/x64/inst/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,19 @@ impl AvxOpcode {
AvxOpcode::Vfmadd213ss
| AvxOpcode::Vfmadd213sd
| AvxOpcode::Vfmadd213ps
| AvxOpcode::Vfmadd213pd => smallvec![InstructionSet::FMA],
| AvxOpcode::Vfmadd213pd
| AvxOpcode::Vfmadd132ss
| AvxOpcode::Vfmadd132sd
| AvxOpcode::Vfmadd132ps
| AvxOpcode::Vfmadd132pd
| AvxOpcode::Vfnmadd213ss
| AvxOpcode::Vfnmadd213sd
| AvxOpcode::Vfnmadd213ps
| AvxOpcode::Vfnmadd213pd
| AvxOpcode::Vfnmadd132ss
| AvxOpcode::Vfnmadd132sd
| AvxOpcode::Vfnmadd132ps
| AvxOpcode::Vfnmadd132pd => smallvec![InstructionSet::FMA],
AvxOpcode::Vminps
| AvxOpcode::Vminpd
| AvxOpcode::Vmaxps
Expand Down
42 changes: 28 additions & 14 deletions cranelift/codegen/src/isa/x64/inst/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2281,32 +2281,46 @@ pub(crate) fn emit(
let dst = allocs.next(dst.to_reg().to_reg());
debug_assert_eq!(src1, dst);
let src2 = allocs.next(src2.to_reg());
let src3 = src3.clone().to_reg_mem().with_allocs(allocs);
let src3 = match src3.clone().to_reg_mem().with_allocs(allocs) {
RegMem::Reg { reg } => {
RegisterOrAmode::Register(reg.to_real_reg().unwrap().hw_enc().into())
}
RegMem::Mem { addr } => RegisterOrAmode::Amode(addr.finalize(state, sink)),
};

let (w, map, opcode) = match op {
AvxOpcode::Vfmadd132ss => (false, OpcodeMap::_0F38, 0x99),
AvxOpcode::Vfmadd213ss => (false, OpcodeMap::_0F38, 0xA9),
AvxOpcode::Vfnmadd132ss => (false, OpcodeMap::_0F38, 0x9D),
AvxOpcode::Vfnmadd213ss => (false, OpcodeMap::_0F38, 0xAD),
AvxOpcode::Vfmadd132sd => (true, OpcodeMap::_0F38, 0x99),
AvxOpcode::Vfmadd213sd => (true, OpcodeMap::_0F38, 0xA9),
AvxOpcode::Vfnmadd132sd => (true, OpcodeMap::_0F38, 0x9D),
AvxOpcode::Vfnmadd213sd => (true, OpcodeMap::_0F38, 0xAD),
AvxOpcode::Vfmadd132ps => (false, OpcodeMap::_0F38, 0x98),
AvxOpcode::Vfmadd213ps => (false, OpcodeMap::_0F38, 0xA8),
AvxOpcode::Vfnmadd132ps => (false, OpcodeMap::_0F38, 0x9C),
AvxOpcode::Vfnmadd213ps => (false, OpcodeMap::_0F38, 0xAC),
AvxOpcode::Vfmadd132pd => (true, OpcodeMap::_0F38, 0x98),
AvxOpcode::Vfmadd213pd => (true, OpcodeMap::_0F38, 0xA8),
AvxOpcode::Vfnmadd132pd => (true, OpcodeMap::_0F38, 0x9C),
AvxOpcode::Vfnmadd213pd => (true, OpcodeMap::_0F38, 0xAC),
AvxOpcode::Vblendvps => (false, OpcodeMap::_0F3A, 0x4A),
AvxOpcode::Vblendvpd => (false, OpcodeMap::_0F3A, 0x4B),
AvxOpcode::Vpblendvb => (false, OpcodeMap::_0F3A, 0x4C),
_ => unreachable!(),
};

match src3 {
RegMem::Reg { reg: src } => VexInstruction::new()
.length(VexVectorLength::V128)
.prefix(LegacyPrefixes::_66)
.map(map)
.w(w)
.opcode(opcode)
.reg(dst.to_real_reg().unwrap().hw_enc())
.rm(src.to_real_reg().unwrap().hw_enc())
.vvvv(src2.to_real_reg().unwrap().hw_enc())
.encode(sink),
_ => todo!(),
};
VexInstruction::new()
.length(VexVectorLength::V128)
.prefix(LegacyPrefixes::_66)
.map(map)
.w(w)
.opcode(opcode)
.reg(dst.to_real_reg().unwrap().hw_enc())
.rm(src3)
.vvvv(src2.to_real_reg().unwrap().hw_enc())
.encode(sink);
}

Inst::XmmRmRBlendVex {
Expand Down
11 changes: 0 additions & 11 deletions cranelift/codegen/src/isa/x64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1944,23 +1944,12 @@ fn x64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut OperandCol
src2.get_operands(collector);
}
Inst::XmmRmRVex3 {
op,
src1,
src2,
src3,
dst,
..
} => {
// Vfmadd uses and defs the dst reg, that is not the case with all
// AVX's ops, if you're adding a new op, make sure to correctly define
// register uses.
assert!(
*op == AvxOpcode::Vfmadd213ss
|| *op == AvxOpcode::Vfmadd213sd
|| *op == AvxOpcode::Vfmadd213ps
|| *op == AvxOpcode::Vfmadd213pd
);

collector.reg_use(src1.to_reg());
collector.reg_reuse_def(dst.to_writable_reg(), 0);
collector.reg_use(src2.to_reg());
Expand Down
97 changes: 81 additions & 16 deletions cranelift/codegen/src/isa/x64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -2167,13 +2167,13 @@
;; The above rules automatically sink loads for rhs operands, so additionally
;; add rules for sinking loads with lhs operands.
(rule 1 (lower (has_type $F32 (fadd (sinkable_load x) y)))
(x64_addss y (sink_load x)))
(x64_addss y x))
(rule 1 (lower (has_type $F64 (fadd (sinkable_load x) y)))
(x64_addsd y (sink_load x)))
(x64_addsd y x))
(rule 1 (lower (has_type $F32X4 (fadd (sinkable_load x) y)))
(x64_addps y (sink_load x)))
(x64_addps y x))
(rule 1 (lower (has_type $F64X2 (fadd (sinkable_load x) y)))
(x64_addpd y (sink_load x)))
(x64_addpd y x))

;; Rules for `fsub` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand All @@ -2200,13 +2200,13 @@
;; The above rules automatically sink loads for rhs operands, so additionally
;; add rules for sinking loads with lhs operands.
(rule 1 (lower (has_type $F32 (fmul (sinkable_load x) y)))
(x64_mulss y (sink_load x)))
(x64_mulss y x))
(rule 1 (lower (has_type $F64 (fmul (sinkable_load x) y)))
(x64_mulsd y (sink_load x)))
(x64_mulsd y x))
(rule 1 (lower (has_type $F32X4 (fmul (sinkable_load x) y)))
(x64_mulps y (sink_load x)))
(x64_mulps y x))
(rule 1 (lower (has_type $F64X2 (fmul (sinkable_load x) y)))
(x64_mulpd y (sink_load x)))
(x64_mulpd y x))

;; Rules for `fdiv` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand Down Expand Up @@ -2438,18 +2438,83 @@

;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Base case for fma is to call out to one of two libcalls. For vectors they
;; need to be decomposed, handle each element individually, and then recomposed.

(rule (lower (has_type $F32 (fma x y z)))
(libcall_3 (LibCall.FmaF32) x y z))
(rule (lower (has_type $F64 (fma x y z)))
(libcall_3 (LibCall.FmaF64) x y z))
(rule 1 (lower (has_type (and (use_fma $true) $F32) (fma x y z)))
(x64_vfmadd213ss x y z))
(rule 1 (lower (has_type (and (use_fma $true) $F64) (fma x y z)))
(x64_vfmadd213sd x y z))
(rule (lower (has_type (and (use_fma $true) $F32X4) (fma x y z)))
(x64_vfmadd213ps x y z))
(rule (lower (has_type (and (use_fma $true) $F64X2) (fma x y z)))
(x64_vfmadd213pd x y z))

;; Libcall fallback for `f32x4`: split the vectors into their four lanes, run
;; the scalar `FmaF32` libcall once per lane, then rebuild the result vector.
;; This costs four calls per `fma`, so it is only meant for targets where no
;; native instruction can be used.
(rule (lower (has_type $F32X4 (fma x y z)))
(let (
;; Force all three inputs into registers once so the per-lane shuffles
;; below reuse the same values.
(x Xmm (put_in_xmm x))
(y Xmm (put_in_xmm y))
(z Xmm (put_in_xmm z))
;; Lane 0: the full vectors are passed as-is.
;; NOTE(review): presumably the scalar libcall only reads the low lane of
;; each argument -- confirm against the libcall ABI.
(x0 Xmm (libcall_3 (LibCall.FmaF32) x y z))
;; Lanes 1..3: `pshufd` with immediate N selects source lane N into the
;; low lane of its result, which is then fed to the scalar libcall.
(x1 Xmm (libcall_3 (LibCall.FmaF32)
(x64_pshufd x 1)
(x64_pshufd y 1)
(x64_pshufd z 1)))
(x2 Xmm (libcall_3 (LibCall.FmaF32)
(x64_pshufd x 2)
(x64_pshufd y 2)
(x64_pshufd z 2)))
(x3 Xmm (libcall_3 (LibCall.FmaF32)
(x64_pshufd x 3)
(x64_pshufd y 3)
(x64_pshufd z 3)))

;; Recompose: start from `x0` (result already in lane 0) and insert the
;; remaining scalar results into lanes 1, 2 and 3. `tmp` is intentionally
;; shadowed at each step.
(tmp Xmm (vec_insert_lane $F32X4 x0 x1 1))
(tmp Xmm (vec_insert_lane $F32X4 tmp x2 2))
(tmp Xmm (vec_insert_lane $F32X4 tmp x3 3))
)
tmp))
;; Libcall fallback for `f64x2`: run the scalar `FmaF64` libcall once per
;; lane and rebuild the two-lane result vector.
(rule (lower (has_type $F64X2 (fma x y z)))
(let (
;; Force all three inputs into registers once so the shuffles below reuse
;; the same values.
(x Xmm (put_in_xmm x))
(y Xmm (put_in_xmm y))
(z Xmm (put_in_xmm z))
;; Lane 0: the full vectors are passed as-is.
;; NOTE(review): presumably the scalar libcall only reads the low lane of
;; each argument -- confirm against the libcall ABI.
(x0 Xmm (libcall_3 (LibCall.FmaF64) x y z))
;; Lane 1: `pshufd` immediate 0xee (0b11_10_11_10) selects 32-bit elements
;; 2 and 3 into the low 64 bits, i.e. it moves the upper f64 lane down to
;; lane 0 before the scalar call.
(x1 Xmm (libcall_3 (LibCall.FmaF64)
(x64_pshufd x 0xee)
(x64_pshufd y 0xee)
(x64_pshufd z 0xee)))
)
;; Recompose: `x0` already holds lane 0; insert the second result as lane 1.
(vec_insert_lane $F64X2 x0 x1 1)))


;; Special case for when the `fma` feature is active and a native instruction
;; can be used.
;;
;; This rule has priority 1 so that it is preferred over the unprioritized
;; libcall-based `fma` rules whenever `use_fma` returns `$true`. It matches
;; any type `ty`; selecting the concrete opcode for that type is delegated to
;; `fmadd`.
(rule 1 (lower (has_type ty (fma x y z)))
(if-let $true (use_fma))
(fmadd ty x y z))

;; `fmadd ty x y z` emits a single instruction computing `x * y + z`;
;; `fnmadd ty x y z` emits a single instruction computing `-(x * y) + z`.
;;
;; For both helpers the rules are tried from highest to lowest priority:
;; `fneg` folding (4, 3) first, then load sinking for a multiplied operand
;; (2, 1), then the unprioritized base case.
(decl fmadd (Type Value Value Value) Xmm)
(decl fnmadd (Type Value Value Value) Xmm)

;; Base case. Note that this will automatically sink a load with `z`, the value
;; to add, because the last operand of the `213` form accepts memory.
(rule (fmadd ty x y z) (x64_vfmadd213 ty x y z))

;; Allow sinking a load for one of the two values being multiplied instead of
;; the value being added. The `132` form computes `a * c + b` with `c` as the
;; memory-capable operand, so the loaded value is passed last and `z` moves to
;; the middle. Either `x` or `y` can be the sunk operand since multiplication
;; is commutative.
(rule 1 (fmadd ty (sinkable_load x) y z) (x64_vfmadd132 ty y z x))
(rule 2 (fmadd ty x (sinkable_load y) z) (x64_vfmadd132 ty x z y))

;; If one of the values being multiplied is negated then use a `vfnmadd*`
;; instruction instead: `(-x) * y + z == -(x * y) + z`.
(rule 3 (fmadd ty (fneg x) y z) (fnmadd ty x y z))
(rule 4 (fmadd ty x (fneg y) z) (fnmadd ty x y z))

;; Same structure as `fmadd` above, using the negated-multiply opcodes: base
;; case first, then the two load-sinking variants.
(rule (fnmadd ty x y z) (x64_vfnmadd213 ty x y z))
(rule 1 (fnmadd ty (sinkable_load x) y z) (x64_vfnmadd132 ty y z x))
(rule 2 (fnmadd ty x (sinkable_load y) z) (x64_vfnmadd132 ty x z y))

;; Like `fmadd`, if one argument is negated switch which helper is being
;; codegen'd: the two negations cancel, so this goes back to `fmadd`.
(rule 3 (fnmadd ty (fneg x) y z) (fmadd ty x y z))
(rule 4 (fnmadd ty x (fneg y) z) (fmadd ty x y z))

;; Rules for `load*` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand Down
2 changes: 1 addition & 1 deletion cranelift/codegen/src/isa/x64/lower/isle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ impl Context for IsleContext<'_, '_, MInst, X64Backend> {
}

#[inline]
fn use_fma(&mut self, _: Type) -> bool {
fn use_fma(&mut self) -> bool {
self.backend.x64_flags.use_fma()
}

Expand Down
Loading

0 comments on commit bd3dcd3

Please sign in to comment.