Skip to content

Commit

Permalink
[RISCV] Support saturated truncate
Browse files Browse the repository at this point in the history
Add support for `ISD::TRUNCATE_[US]SAT`.
  • Loading branch information
ParkHanbum committed Jul 17, 2024
1 parent c80af03 commit 108a26c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
40 changes: 30 additions & 10 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
// nodes which truncate by one power of two at a time.
setOperationAction(ISD::TRUNCATE, VT, Custom);
setOperationAction(
{ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT, Custom);

// Custom-lower insert/extract operations to simplify patterns.
setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT,
Expand Down Expand Up @@ -1168,7 +1169,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

setOperationAction(ISD::SELECT, VT, Custom);

setOperationAction(ISD::TRUNCATE, VT, Custom);
setOperationAction(
{ISD::TRUNCATE, ISD::TRUNCATE_SSAT, ISD::TRUNCATE_USAT}, VT,
Custom);

setOperationAction(ISD::BITCAST, VT, Custom);

Expand Down Expand Up @@ -1479,8 +1482,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});

if ((Subtarget.hasStdExtZbs() && Subtarget.is64Bit()) ||
Subtarget.hasStdExtV())
Subtarget.hasStdExtV()) {
setTargetDAGCombine(ISD::TRUNCATE);
setTargetDAGCombine(ISD::TRUNCATE_SSAT);
setTargetDAGCombine(ISD::TRUNCATE_USAT);
}

if (Subtarget.hasStdExtZbkb())
setTargetDAGCombine(ISD::BITREVERSE);
Expand Down Expand Up @@ -6092,7 +6098,7 @@ static bool hasMergeOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
130 &&
132 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
Expand All @@ -6118,7 +6124,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
130 &&
132 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
Expand Down Expand Up @@ -6389,6 +6395,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap);
}
case ISD::TRUNCATE:
case ISD::TRUNCATE_SSAT:
case ISD::TRUNCATE_USAT:
// Only custom-lower vector truncates
if (!Op.getSimpleValueType().isVector())
return Op;
Expand Down Expand Up @@ -8275,11 +8283,15 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,

LLVMContext &Context = *DAG.getContext();
const ElementCount Count = ContainerVT.getVectorElementCount();
unsigned NewOpc = RISCVISD::TRUNCATE_VECTOR_VL;
if (Op.getOpcode() == ISD::TRUNCATE_SSAT)
NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
else if (Op.getOpcode() == ISD::TRUNCATE_USAT)
NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
do {
SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
Mask, VL);
Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL);
} while (SrcEltVT != DstEltVT);

if (SrcVT.isFixedLengthVector())
Expand Down Expand Up @@ -16512,7 +16524,9 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
// minimum value.
static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL ||
N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT ||
N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT);

MVT VT = N->getSimpleValueType(0);

Expand Down Expand Up @@ -16617,9 +16631,11 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,

SDValue Val;
unsigned ClipOpc;
if ((Val = DetectUSatPattern(Src)))

Val = N->getOperand(0);
if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_USAT)
ClipOpc = RISCVISD::VNCLIPU_VL;
else if ((Val = DetectSSatPattern(Src)))
else if (N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL_SSAT)
ClipOpc = RISCVISD::VNCLIP_VL;
else
return SDValue();
Expand Down Expand Up @@ -16857,6 +16873,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
return SDValue();
case RISCVISD::TRUNCATE_VECTOR_VL:
case RISCVISD::TRUNCATE_VECTOR_VL_SSAT:
case RISCVISD::TRUNCATE_VECTOR_VL_USAT:
if (SDValue V = combineTruncOfSraSext(N, DAG))
return V;
return combineTruncToVnclip(N, DAG, Subtarget);
Expand Down Expand Up @@ -20433,6 +20451,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDE1UP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ enum NodeType : unsigned {
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL,
TRUNCATE_VECTOR_VL_SSAT,
TRUNCATE_VECTOR_VL_USAT,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the XLenVT
// index (either constant or non-constant), the fourth is the mask, the fifth
Expand Down

0 comments on commit 108a26c

Please sign in to comment.