Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Xtensa] Fix FP mul-sub fusion (LLVM-276) #76

Closed
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions llvm/lib/Target/Xtensa/XtensaISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,20 @@ MVT XtensaTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}

bool XtensaTargetLowering::isFNegFree(EVT VT) const {
if (!VT.isSimple())
return false;

switch (VT.getSimpleVT().SimpleTy) {
case MVT::f32:
return Subtarget.hasSingleFloat();
default:
break;
}

return false;
}

bool XtensaTargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
EVT VT) const {
if (!VT.isSimple())
Expand Down Expand Up @@ -512,33 +526,38 @@ void XtensaTargetLowering::LowerAsmOperandForConstraint(

static SDValue performMADD_MSUBCombine(SDNode *ROOTNode, SelectionDAG &CurDAG,
const XtensaSubtarget &Subtarget) {
if (ROOTNode->getOperand(0).getValueType() != MVT::f32)
return SDValue();
SDValue LHS = ROOTNode->getOperand(0);
SDValue RHS = ROOTNode->getOperand(1);

if (ROOTNode->getOperand(0).getOpcode() != ISD::FMUL &&
ROOTNode->getOperand(1).getOpcode() != ISD::FMUL)
if (LHS.getValueType() != MVT::f32 || (LHS.getOpcode() != ISD::FMUL && RHS.getOpcode() != ISD::FMUL))
return SDValue();

SDValue Mult = ROOTNode->getOperand(0).getOpcode() == ISD::FMUL
? ROOTNode->getOperand(0)
: ROOTNode->getOperand(1);
SDLoc DL(ROOTNode);
bool IsAdd = ROOTNode->getOpcode() == ISD::FADD;

SDValue Mult, AddOperand;
bool Inverted;

SDValue AddOperand = ROOTNode->getOperand(0).getOpcode() == ISD::FMUL
? ROOTNode->getOperand(1)
: ROOTNode->getOperand(0);
if (LHS.getOpcode() == ISD::FMUL)
Mult = LHS, AddOperand = RHS, Inverted = false;
else
Mult = RHS, AddOperand = LHS, Inverted = true;

if (!Mult.hasOneUse())
return SDValue();

SDLoc DL(ROOTNode);
SDValue MultOperand0 = Mult->getOperand(0), MultOperand1 = Mult->getOperand(1);

bool IsAdd = ROOTNode->getOpcode() == ISD::FADD;
unsigned Opcode = IsAdd ? XtensaISD::MADD : XtensaISD::MSUB;
SDValue MAddOps[3] = {AddOperand, Mult->getOperand(0), Mult->getOperand(1)};
if (!IsAdd)
if (Inverted)
MultOperand0 = CurDAG.getNode(ISD::FNEG, DL, MVT::f32, MultOperand0);
else
AddOperand = CurDAG.getNode(ISD::FNEG, DL, MVT::f32, AddOperand);

SDValue FMAOps[3] = {MultOperand0, MultOperand1, AddOperand};
EVT VTs[3] = {MVT::f32, MVT::f32, MVT::f32};
SDValue MAdd = CurDAG.getNode(Opcode, DL, VTs, MAddOps);

return MAdd;
return CurDAG.getNode(ISD::FMA, DL, VTs, FMAOps);
}

static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/Xtensa/XtensaISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class XtensaTargetLowering : public TargetLowering {
bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
EVT VT) const override;

bool isFNegFree(EVT VT) const override;

/// If a physical register, this returns the register that receives the
/// exception address on entry to an EH pad.
Register
Expand Down
9 changes: 7 additions & 2 deletions llvm/lib/Target/Xtensa/XtensaInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1112,15 +1112,16 @@ def FLOOR_S : RRR_Inst<0x00, 0x0A, 0x0A, (outs AR:$r), (ins FPR:$s, uimm4:$imm),
let t = imm;
}

def MADDN_S : RRR_Inst<0x00, 0x0A, 0x06, (outs FPR:$r), (ins FPR:$s, FPR:$t),
def MADDN_S : RRR_Inst<0x00, 0x0A, 0x06, (outs FPR:$r), (ins FPR:$a, FPR:$s, FPR:$t),
"maddn.s\t$r, $s, $t", []>, Requires<[HasSingleFloat]> {
let isCommutable = 0;
let Constraints = "$r = $a";
}

// FP multipy-add
def MADD_S : RRR_Inst<0x00, 0x0A, 0x04, (outs FPR:$r), (ins FPR:$a, FPR:$s, FPR:$t),
"madd.s\t$r, $s, $t",
[(set FPR:$r, (Xtensa_madd FPR:$a, FPR:$s, FPR:$t))]>,
[(set FPR:$r, (Xtensa_madd FPR:$a, FPR:$s, FPR:$t))]>,
Requires<[HasSingleFloat]> {
let isCommutable = 0;
let isReMaterializable = 0;
Expand Down Expand Up @@ -1175,6 +1176,10 @@ def MSUB_S : RRR_Inst<0x00, 0x0A, 0x05, (outs FPR:$r), (ins FPR:$a, FPR:$s, FPR:
let Constraints = "$r = $a";
}

// fmsub: -r1 * r2 + r3
def : Pat<(fma (fneg FPR:$r1), FPR:$r2, FPR:$r3),
(MSUB_S $r3, $r1, $r2)>;

def NEXP01_S : RRR_Inst<0x00, 0x0A, 0x0F, (outs FPR:$r), (ins FPR:$s),
"nexp01.s\t$r, $s", []>, Requires<[HasSingleFloat]> {
let t = 0x0B;
Expand Down
147 changes: 147 additions & 0 deletions llvm/test/CodeGen/Xtensa/float-fma.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=xtensa -mcpu=esp32 -verify-machineinstrs < %s \
; RUN: | FileCheck -check-prefix=XTENSA %s

define float @fmadd_s(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fmadd_s:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: madd.s f10, f9, f8
; XTENSA-NEXT: rfr a2, f10
; XTENSA-NEXT: retw.n
%mul = fmul float %a, %b
%add = fadd float %mul, %c
ret float %add
}

define float @fmsub_s(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fmsub_s:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: neg.s f10, f10
; XTENSA-NEXT: madd.s f10, f9, f8
; XTENSA-NEXT: rfr a2, f10
; XTENSA-NEXT: retw.n
%mul = fmul float %a, %b
%sub = fsub float %mul, %c
ret float %sub
}

define float @fnmadd_s(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fnmadd_s:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: madd.s f10, f9, f8
; XTENSA-NEXT: neg.s f8, f10
; XTENSA-NEXT: rfr a2, f8
; XTENSA-NEXT: retw.n
%mul = fmul float %a, %b
%add = fadd float %mul, %c
%negadd = fneg float %add
ret float %negadd
}


define float @fnmsub_s(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fnmsub_s:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: msub.s f10, f9, f8
; XTENSA-NEXT: rfr a2, f10
; XTENSA-NEXT: retw.n
%nega = fneg float %a
%mul = fmul float %nega, %b
%add = fadd float %mul, %c
ret float %add
}

declare float @llvm.fma.f32(float, float, float)

define float @fmadd_s_intrinsics(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fmadd_s_intrinsics:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: madd.s f10, f9, f8
; XTENSA-NEXT: rfr a2, f10
; XTENSA-NEXT: retw.n
%fma = call float @llvm.fma.f32(float %a, float %b, float %c)
ret float %fma
}

define float @fmsub_s_intrinsics(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fmsub_s_intrinsics:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: neg.s f10, f10
; XTENSA-NEXT: madd.s f10, f9, f8
; XTENSA-NEXT: rfr a2, f10
; XTENSA-NEXT: retw.n
%negc = fneg float %c
%fma = call float @llvm.fma.f32(float %a, float %b, float %negc)
ret float %fma
}

define float @fnmadd_s_intrinsics(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fnmadd_s_intrinsics:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: madd.s f10, f9, f8
; XTENSA-NEXT: neg.s f8, f10
; XTENSA-NEXT: rfr a2, f8
; XTENSA-NEXT: retw.n
%fma = call float @llvm.fma.f32(float %a, float %b, float %c)
%neg = fneg float %fma
ret float %neg
}

define float @fnmsub_s_intrinsics(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fnmsub_s_intrinsics:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: msub.s f10, f9, f8
; XTENSA-NEXT: rfr a2, f10
; XTENSA-NEXT: retw.n
%nega = fneg float %a
%fma = call float @llvm.fma.f32(float %nega, float %b, float %c)
ret float %fma
}

define float @fnmsub_s_swap_intrinsics(float %a, float %b, float %c) nounwind {
; XTENSA-LABEL: fnmsub_s_swap_intrinsics:
; XTENSA: # %bb.0:
; XTENSA-NEXT: entry a1, 32
; XTENSA-NEXT: wfr f8, a3
; XTENSA-NEXT: wfr f9, a2
; XTENSA-NEXT: wfr f10, a4
; XTENSA-NEXT: neg.s f10, f10
; XTENSA-NEXT: madd.s f10, f9, f8
; XTENSA-NEXT: rfr a2, f10
; XTENSA-NEXT: retw.n
%negc = fneg float %c
%fma = call float @llvm.fma.f32(float %a, float %b, float %negc)
andreisfr marked this conversation as resolved.
Show resolved Hide resolved
ret float %fma
}