From 108a26c9a7a771d23a7ada07172dccabf5e0a89c Mon Sep 17 00:00:00 2001 From: hanbeom Date: Tue, 16 Jul 2024 14:14:40 +0900 Subject: [PATCH] [RISCV] Support saturated truncate Add support for `ISD::TRUNCATE_[US]SAT`. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 40 +++++++++++++++------ llvm/lib/Target/RISCV/RISCVISelLowering.h | 2 ++ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 953196a586b6e4..3b54416d1b5b23 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -853,7 +853,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL" // nodes which truncate by one power of two at a time. - setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction( + {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, Custom); // Custom-lower insert/extract operations to simplify patterns. setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT, @@ -1168,7 +1169,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::SELECT, VT, Custom); - setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction( + {ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, + Custom); setOperationAction(ISD::BITCAST, VT, Custom); @@ -1479,8 +1482,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN}); if ((Subtarget.hasStdExtZbs() && Subtarget.is64Bit()) || - Subtarget.hasStdExtV()) + Subtarget.hasStdExtV()) { setTargetDAGCombine(ISD::TRUNCATE); + setTargetDAGCombine(ISD::TRUNCATE_SSAT); + setTargetDAGCombine(ISD::TRUNCATE_USAT); + } if (Subtarget.hasStdExtZbkb()) setTargetDAGCombine(ISD::BITREVERSE); @@ -6092,7 +6098,7 @@ static bool hasMergeOp(unsigned Opcode) { Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE && "not a RISC-V target specific op"); static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == - 130 && + 132 && RISCVISD::LAST_RISCV_STRICTFP_OPCODE - ISD::FIRST_TARGET_STRICTFP_OPCODE == 21 && @@ -6118,7 +6124,7 @@ static bool hasMaskOp(unsigned Opcode) { Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE && "not a RISC-V target specific op"); static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == - 130 && + 132 && RISCVISD::LAST_RISCV_STRICTFP_OPCODE - ISD::FIRST_TARGET_STRICTFP_OPCODE == 21 && @@ -6389,6 +6395,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap); } case ISD::TRUNCATE: + case ISD::TRUNCATE_SSAT: + case ISD::TRUNCATE_USAT: // Only custom-lower vector truncates if (!Op.getSimpleValueType().isVector()) return Op; @@ -8275,11 +8283,15 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op, LLVMContext &Context = *DAG.getContext(); const ElementCount Count = ContainerVT.getVectorElementCount(); + unsigned NewOpc = RISCVISD::TRUNCATE_VECTOR_VL; + if (Op.getOpcode() == ISD::TRUNCATE_SSAT) + NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT; + else if (Op.getOpcode() == ISD::TRUNCATE_USAT) + NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT; do { SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2); EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count); - Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result, - Mask, VL); + Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL); } while (SrcEltVT != DstEltVT); if (SrcVT.isFixedLengthVector()) @@ -16512,7 +16524,9 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) { // minimum value. static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL); + assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL || + N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT || + N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT); MVT VT = N->getSimpleValueType(0); @@ -16617,9 +16631,11 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG, SDValue Val; unsigned ClipOpc; - if ((Val = DetectUSatPattern(Src))) + + Val = N->getOperand(0); + if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT) ClipOpc = RISCVISD::VNCLIPU_VL; - else if ((Val = DetectSSatPattern(Src))) + else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT) ClipOpc = RISCVISD::VNCLIP_VL; else return SDValue(); @@ -16857,6 +16873,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, } return SDValue(); case RISCVISD::TRUNCATE_VECTOR_VL: + case RISCVISD::TRUNCATE_VECTOR_VL_SSAT: + case RISCVISD::TRUNCATE_VECTOR_VL_USAT: if (SDValue V = combineTruncOfSraSext(N, DAG)) return V; return combineTruncToVnclip(N, DAG, Subtarget); @@ -20433,6 +20451,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL) NODE_NAME_CASE(READ_VLENB) NODE_NAME_CASE(TRUNCATE_VECTOR_VL) + NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT) + NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT) NODE_NAME_CASE(VSLIDEUP_VL) NODE_NAME_CASE(VSLIDE1UP_VL) NODE_NAME_CASE(VSLIDEDOWN_VL) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 0b0ad9229f0b35..3d582fcdaf64bb 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -181,6 +181,8 @@ enum NodeType : unsigned { // Truncates a RVV integer vector by one power-of-two. Carries both an extra // mask and VL operand. TRUNCATE_VECTOR_VL, + TRUNCATE_VECTOR_VL_SSAT, + TRUNCATE_VECTOR_VL_USAT, // Matches the semantics of vslideup/vslidedown. The first operand is the // pass-thru operand, the second is the source vector, the third is the XLenVT // index (either constant or non-constant), the fourth is the mask, the fifth