Skip to content

Commit

Permalink
[RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/…
Browse files Browse the repository at this point in the history
…USAT opcodes (#100173)

Summary:
These new opcodes drop the shift amount, rounding mode, and passthru.
Making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding
mode, and passthru are added in isel patterns similar to how we
translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0.

This should simplify #99418 a little.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251265
  • Loading branch information
topperc authored and yuxuanchen1997 committed Jul 25, 2024
1 parent 34e9ada commit 4237e3f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 106 deletions.
25 changes: 8 additions & 17 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2997,13 +2997,9 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
CvtEltVT = MVT::getIntegerVT(CvtEltVT.getSizeInBits() / 2);
CvtContainerVT = CvtContainerVT.changeVectorElementType(CvtEltVT);
// Rounding mode here is arbitrary since we aren't shifting out any bits.
unsigned ClipOpc = IsSigned ? RISCVISD::VNCLIP_VL : RISCVISD::VNCLIPU_VL;
Res = DAG.getNode(
ClipOpc, DL, CvtContainerVT,
{Res, DAG.getConstant(0, DL, CvtContainerVT),
DAG.getUNDEF(CvtContainerVT), Mask,
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
VL});
unsigned ClipOpc = IsSigned ? RISCVISD::TRUNCATE_VECTOR_VL_SSAT
: RISCVISD::TRUNCATE_VECTOR_VL_USAT;
Res = DAG.getNode(ClipOpc, DL, CvtContainerVT, Res, Mask, VL);
}

SDValue SplatZero = DAG.getNode(
Expand Down Expand Up @@ -16643,9 +16639,9 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
SDValue Val;
unsigned ClipOpc;
if ((Val = DetectUSatPattern(Src)))
ClipOpc = RISCVISD::VNCLIPU_VL;
ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
else if ((Val = DetectSSatPattern(Src)))
ClipOpc = RISCVISD::VNCLIP_VL;
ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
else
return SDValue();

Expand All @@ -16654,12 +16650,7 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
do {
MVT ValEltVT = MVT::getIntegerVT(ValVT.getScalarSizeInBits() / 2);
ValVT = ValVT.changeVectorElementType(ValEltVT);
// Rounding mode here is arbitrary since we aren't shifting out any bits.
Val = DAG.getNode(
ClipOpc, DL, ValVT,
{Val, DAG.getConstant(0, DL, ValVT), DAG.getUNDEF(VT), Mask,
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
VL});
Val = DAG.getNode(ClipOpc, DL, ValVT, Val, Mask, VL);
} while (ValVT != VT);

return Val;
Expand Down Expand Up @@ -20463,6 +20454,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDE1UP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
Expand Down Expand Up @@ -20506,8 +20499,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UADDSAT_VL)
NODE_NAME_CASE(SSUBSAT_VL)
NODE_NAME_CASE(USUBSAT_VL)
NODE_NAME_CASE(VNCLIP_VL)
NODE_NAME_CASE(VNCLIPU_VL)
NODE_NAME_CASE(FADD_VL)
NODE_NAME_CASE(FSUB_VL)
NODE_NAME_CASE(FMUL_VL)
Expand Down
10 changes: 6 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ enum NodeType : unsigned {
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL,
// Truncates a RVV integer vector by one power-of-two. If the value doesn't
// fit in the destination type, the result is saturated. These correspond to
// vnclip and vnclipu with a shift of 0. Carries both an extra mask and VL
// operand.
TRUNCATE_VECTOR_VL_SSAT,
TRUNCATE_VECTOR_VL_USAT,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the XLenVT
// index (either constant or non-constant), the fourth is the mask, the fifth
Expand Down Expand Up @@ -273,10 +279,6 @@ enum NodeType : unsigned {
// Rounding averaging adds of unsigned integers.
AVGCEILU_VL,

// Operands are (source, shift, merge, mask, roundmode, vl)
VNCLIPU_VL,
VNCLIP_VL,

MULHS_VL,
MULHU_VL,
FADD_VL,
Expand Down
115 changes: 30 additions & 85 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,6 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;

def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;

def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
Expand Down Expand Up @@ -408,12 +405,17 @@ def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
[(riscv_sext_vl node:$A, node:$B, node:$C),
(riscv_zext_vl node:$A, node:$B, node:$C)]>;

def SDT_RISCVVTRUNCATE_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
SDTCisSameNumEltsAs<0, 1>,
SDTCisSameNumEltsAs<0, 2>,
SDTCVecEltisVT<2, i1>,
SDTCisVT<3, XLenVT>]>;
def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
SDTypeProfile<1, 3, [SDTCisVec<0>,
SDTCisSameNumEltsAs<0, 1>,
SDTCisSameNumEltsAs<0, 2>,
SDTCVecEltisVT<2, i1>,
SDTCisVT<3, XLenVT>]>>;
SDT_RISCVVTRUNCATE_VL>;
def riscv_trunc_vector_vl_ssat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_SSAT",
SDT_RISCVVTRUNCATE_VL>;
def riscv_trunc_vector_vl_usat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_USAT",
SDT_RISCVVTRUNCATE_VL>;

def SDT_RISCVVWIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
SDTCisInt<1>,
Expand Down Expand Up @@ -650,34 +652,6 @@ class VPatBinaryVL_V<SDPatternOperator vop,
op2_reg_class:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;

multiclass VPatBinaryRM_VL_V<SDNode vop,
string instruction_name,
string suffix,
ValueType result_type,
ValueType op1_type,
ValueType op2_type,
ValueType mask_type,
int sew,
LMULInfo vlmul,
VReg result_reg_class,
VReg op1_reg_class,
VReg op2_reg_class> {
def : Pat<(result_type (vop
(op1_type op1_reg_class:$rs1),
(op2_type op2_reg_class:$rs2),
(result_type result_reg_class:$merge),
(mask_type V0),
(XLenVT timm:$roundmode),
VLOpFrag)),
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
result_reg_class:$merge,
op1_reg_class:$rs1,
op2_reg_class:$rs2,
(mask_type V0),
(XLenVT timm:$roundmode),
GPR:$vl, sew, TAIL_AGNOSTIC)>;
}

class VPatBinaryVL_V_RM<SDPatternOperator vop,
string instruction_name,
string suffix,
Expand Down Expand Up @@ -838,35 +812,6 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
xop_kind:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;

multiclass VPatBinaryRM_VL_XI<SDNode vop,
string instruction_name,
string suffix,
ValueType result_type,
ValueType vop1_type,
ValueType vop2_type,
ValueType mask_type,
int sew,
LMULInfo vlmul,
VReg result_reg_class,
VReg vop_reg_class,
ComplexPattern SplatPatKind,
DAGOperand xop_kind> {
def : Pat<(result_type (vop
(vop1_type vop_reg_class:$rs1),
(vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
(result_type result_reg_class:$merge),
(mask_type V0),
(XLenVT timm:$roundmode),
VLOpFrag)),
(!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
result_reg_class:$merge,
vop_reg_class:$rs1,
xop_kind:$rs2,
(mask_type V0),
(XLenVT timm:$roundmode),
GPR:$vl, sew, TAIL_AGNOSTIC)>;
}

multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
list<VTypeInfo> vtilist = AllIntegerVectors,
bit isSEWAware = 0> {
Expand Down Expand Up @@ -965,24 +910,6 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
}
}

multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
foreach VtiToWti = AllWidenableIntVectors in {
defvar vti = VtiToWti.Vti;
defvar wti = VtiToWti.Wti;
defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
!cast<ComplexPattern>(SplatPat#_#uimm5),
uimm5>;
}
}

class VPatBinaryVL_VF<SDPatternOperator vop,
string instruction_name,
ValueType result_type,
Expand Down Expand Up @@ -2468,8 +2395,26 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;

// 12.5. Vector Narrowing Fixed-Point Clip Instructions
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
// Rounding mode here is arbitrary since we aren't shifting out any bits.
def : Pat<(vti.Vector (riscv_trunc_vector_vl_ssat (wti.Vector wti.RegClass:$rs1),
(vti.Mask V0),
VLOpFrag)),
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
def : Pat<(vti.Vector (riscv_trunc_vector_vl_usat (wti.Vector wti.RegClass:$rs1),
(vti.Mask V0),
VLOpFrag)),
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
}
}

// 13. Vector Floating-Point Instructions

Expand Down

0 comments on commit 4237e3f

Please sign in to comment.