Skip to content

Commit

Permalink
[DAG] Support saturated truncate
Browse files Browse the repository at this point in the history
A truncate is considered saturated if no additional conversion is
required between the target and return values. If the target is
saturated when attempting to truncate from a vector, there is an
opportunity to optimize it.

Previously, each architecture had its own attempt at optimization,
leading to redundant code. This patch implements common logic by
introducing three new ISDs:

 `ISD::TRUNCATE_SSAT_S`: When the operand is a signed value and
 the range of values matches the range of signed values of the
 destination type.

 `ISD::TRUNCATE_SSAT_U`: When the operand is a signed value and
 the range of values matches the range of unsigned values of the
 destination type.

 `ISD::TRUNCATE_USAT_U`: When the operand is an unsigned value and
 the range of values matches the range of unsigned values of the
 destination type.

These ISDs indicate a saturated truncate.

Fixes #85903
  • Loading branch information
ParkHanbum committed Aug 9, 2024
1 parent f8006a5 commit 8d81896
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 1 deletion.
20 changes: 20 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,26 @@ enum NodeType {

/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
/// TRUNCATE_[SU]SAT_[SU] - Truncate for saturated operand.
/// The [SU] in the middle (the prefix of `SAT`) indicates whether the value
/// being truncated was produced by a signed operation. For example,
/// if `truncate(smin(smax(x, C), C))` was saturated then it becomes `S`.
/// If `truncate(umin(x, C))` was saturated then it becomes `U`.
/// The trailing [SU] indicates whether the range of truncated values is
/// sign-saturated. For example, if `truncate(smin(smax(x, C), C))` is a
/// truncation to `i8`, then if the value of C ranges from `-128 to 127`, it
/// will be saturated against signed values, resulting in `S`, which will
/// combine to `TRUNCATE_SSAT_S`. If the value of C ranges from `0 to 255`, it
/// will be saturated against unsigned values, resulting in `U`, which will
/// combine to `TRUNCATE_SSAT_U`. Similarly, in `truncate(umin(x, C))`, if the
/// value of C ranges from `0 to 255`, it becomes `U` because it is saturated
/// for unsigned values. As a result, it combines to `TRUNCATE_USAT_U`.
TRUNCATE_SSAT_S, // saturate signed input to signed result -
// truncate(smin(smax(x, C), C))
TRUNCATE_SSAT_U, // saturate signed input to unsigned result -
// truncate(smin(smax(x, 0), C))
TRUNCATE_USAT_U, // saturate unsigned input to unsigned result -
// truncate(umin(x, C))

/// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
/// depends on the first letter) to floating point.
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ def sext : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
def zext : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
def anyext : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
def trunc : SDNode<"ISD::TRUNCATE" , SDTIntTruncOp>;
def truncssat_s : SDNode<"ISD::TRUNCATE_SSAT_S", SDTIntTruncOp>;
def truncssat_u : SDNode<"ISD::TRUNCATE_SSAT_U", SDTIntTruncOp>;
def truncusat_u : SDNode<"ISD::TRUNCATE_USAT_U", SDTIntTruncOp>;
def bitconvert : SDNode<"ISD::BITCAST" , SDTUnaryOp>;
def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
def freeze : SDNode<"ISD::FREEZE" , SDTFreeze>;
Expand Down
112 changes: 111 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ namespace {
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
SDValue visitTRUNCATE(SDNode *N);
SDValue visitTRUNCATE_USAT(SDNode *N);
SDValue visitBITCAST(SDNode *N);
SDValue visitFREEZE(SDNode *N);
SDValue visitBUILD_PAIR(SDNode *N);
Expand Down Expand Up @@ -13203,7 +13204,9 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
unsigned CastOpcode = Cast->getOpcode();
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
CastOpcode == ISD::FP_ROUND) &&
CastOpcode == ISD::TRUNCATE_SSAT_S ||
CastOpcode == ISD::TRUNCATE_SSAT_U ||
CastOpcode == ISD::TRUNCATE_USAT_U || CastOpcode == ISD::FP_ROUND) &&
"Unexpected opcode for vector select narrowing/widening");

// We only do this transform before legal ops because the pattern may be
Expand Down Expand Up @@ -14915,6 +14918,109 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
return SDValue();
}

/// Recognize an unsigned-saturating truncation source:
///
///   (truncate (umin x, UMax)) to VT
///
/// where UMax is the largest unsigned value representable in VT's scalar
/// type, zero-extended to the scalar width of \p In.
///
/// \param In the value feeding the truncate.
/// \param VT the truncate's destination type.
/// \returns the clamped operand x when \p In has that shape, otherwise an
/// empty SDValue.
static SDValue detectUSatUPattern(SDValue In, EVT VT) {
  const unsigned NumDstBits = VT.getScalarSizeInBits();
  const unsigned NumSrcBits = In.getScalarValueSizeInBits();
  // The caller is truncating In to VT, so the source must be strictly wider.
  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");

  // Upper clamp bound, expressed at the source width so it can be compared
  // against the constant operand of the umin.
  const APInt DstUMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);

  SDValue Clamped;
  const bool IsUMinClamp =
      sd_match(In, m_UMin(m_Value(Clamped), m_SpecificInt(DstUMax)));
  return IsUMinClamp ? Clamped : SDValue();
}

/// Recognize a signed-saturating truncation source. Two operand orders are
/// accepted:
///
///   (truncate (smin (smax x, SMin), SMax)) to VT
///   (truncate (smax (smin x, SMax), SMin)) to VT
///
/// where SMin / SMax are the signed minimum / maximum of VT's scalar type,
/// sign-extended to the scalar width of \p In.
///
/// \param In the value feeding the truncate.
/// \param VT the truncate's destination type.
/// \returns the clamped operand x when \p In has one of those shapes,
/// otherwise an empty SDValue.
static SDValue detectSSatSPattern(SDValue In, EVT VT) {
  const unsigned NumDstBits = VT.getScalarSizeInBits();
  const unsigned NumSrcBits = In.getScalarValueSizeInBits();
  // The caller is truncating In to VT, so the source must be strictly wider.
  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");

  // Clamp bounds of the destination type, expressed at the source width.
  const APInt DstSMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
  const APInt DstSMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);

  SDValue Clamped;
  // Try smin-of-smax first, then smax-of-smin; `||` short-circuits so the
  // second pattern is only attempted when the first one fails.
  const bool IsSignedClamp =
      sd_match(In, m_SMin(m_SMax(m_Value(Clamped), m_SpecificInt(DstSMin)),
                          m_SpecificInt(DstSMax))) ||
      sd_match(In, m_SMax(m_SMin(m_Value(Clamped), m_SpecificInt(DstSMax)),
                          m_SpecificInt(DstSMin)));
  if (!IsSignedClamp)
    return SDValue();

  return Clamped;
}

/// Recognize a truncation source that clamps a signed value to the unsigned
/// range of the destination type. Three shapes are accepted:
///
///   (truncate (smax (smin x, UMax), 0)) to VT
///   (truncate (smin (smax x, 0), UMax)) to VT
///   (truncate (umin (smax x, 0), UMax)) to VT
///
/// where UMax is the largest unsigned value representable in VT's scalar
/// type, zero-extended to the scalar width of \p In.
///
/// \param In the value feeding the truncate.
/// \param VT the truncate's destination type.
/// \param DAG not used by this function.
/// \param DL not used by this function.
/// \returns the clamped operand x when \p In has one of those shapes,
/// otherwise an empty SDValue.
static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
                                  const SDLoc &DL) {
  const unsigned NumDstBits = VT.getScalarSizeInBits();
  const unsigned NumSrcBits = In.getScalarValueSizeInBits();
  // The caller is truncating In to VT, so the source must be strictly wider.
  assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");

  // Lower bound is zero; upper bound is the destination's unsigned maximum
  // expressed at the source width.
  const APInt DstUMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);

  SDValue Clamped;
  // The alternatives are tried in order and evaluation stops at the first
  // one that matches.
  const bool IsUnsignedRangeClamp =
      sd_match(In, m_SMax(m_SMin(m_Value(Clamped), m_SpecificInt(DstUMax)),
                          m_Zero())) ||
      sd_match(In, m_SMin(m_SMax(m_Value(Clamped), m_Zero()),
                          m_SpecificInt(DstUMax))) ||
      sd_match(In, m_UMin(m_SMax(m_Value(Clamped), m_Zero()),
                          m_SpecificInt(DstUMax)));
  if (!IsUnsignedRangeClamp)
    return SDValue();

  return Clamped;
}

/// Try to replace `truncate(Src)` with one of the saturating truncate nodes
/// (TRUNCATE_SSAT_S, TRUNCATE_SSAT_U, TRUNCATE_USAT_U).
///
/// A saturating opcode is only considered when the target reports it as legal
/// or custom for \p SrcVT and reports \p VT as a desirable type for it.
///
/// \param N the truncate node being combined (not used by this function).
/// \param VT the truncate's result type.
/// \param Src the truncate's operand; only SMIN, SMAX and UMIN roots are
///        examined.
/// \param SrcVT the type the legality query is made against.
/// \param DL debug location for any node that is created.
/// \param TLI target lowering information used for the legality queries.
/// \param DAG the DAG in which the replacement node is created.
/// \returns the new saturating truncate node, or an empty SDValue if no
/// pattern matched or the target does not support the opcode.
static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
                               SDLoc &DL, const TargetLowering &TLI,
                               SelectionDAG &DAG) {
  // True when the target can handle saturating-truncate opcode Opc for this
  // source/destination type pair.
  const auto IsSatTruncUsable = [&TLI, &SrcVT, &VT](unsigned Opc) -> bool {
    return TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
           TLI.isTypeDesirableForOp(Opc, VT);
  };

  switch (Src.getOpcode()) {
  case ISD::SMIN:
  case ISD::SMAX:
    // Signed clamp root: prefer the signed-range form, then the
    // unsigned-range form.
    if (IsSatTruncUsable(ISD::TRUNCATE_SSAT_S))
      if (SDValue Clamped = detectSSatSPattern(Src, VT))
        return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, Clamped);
    if (IsSatTruncUsable(ISD::TRUNCATE_SSAT_U))
      if (SDValue Clamped = detectSSatUPattern(Src, VT, DAG, DL))
        return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, Clamped);
    break;
  case ISD::UMIN:
    // A umin root may still wrap a signed lower clamp (umin(smax(x, 0), C)),
    // so the signed-input form is tried before the purely unsigned one.
    if (IsSatTruncUsable(ISD::TRUNCATE_SSAT_U))
      if (SDValue Clamped = detectSSatUPattern(Src, VT, DAG, DL))
        return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, Clamped);
    if (IsSatTruncUsable(ISD::TRUNCATE_USAT_U))
      if (SDValue Clamped = detectUSatUPattern(Src, VT))
        return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, Clamped);
    break;
  default:
    break;
  }

  return SDValue();
}

SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
Expand All @@ -14930,6 +15036,10 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
if (N0.getOpcode() == ISD::TRUNCATE)
return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));

// fold saturated truncate
if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG))
return SaturatedTR;

// fold (truncate c1) -> c1
if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
return C;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::SIGN_EXTEND_VECTOR_INREG: return "sign_extend_vector_inreg";
case ISD::ZERO_EXTEND_VECTOR_INREG: return "zero_extend_vector_inreg";
case ISD::TRUNCATE: return "truncate";
case ISD::TRUNCATE_SSAT_S: return "truncate_ssat_s";
case ISD::TRUNCATE_SSAT_U: return "truncate_ssat_u";
case ISD::TRUNCATE_USAT_U: return "truncate_usat_u";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
case ISD::FP_EXTEND: return "fp_extend";
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,11 @@ void TargetLoweringBase::initActions() {
// Absolute difference
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);

// Saturated trunc
setOperationAction(ISD::TRUNCATE_SSAT_S, VT, Expand);
setOperationAction(ISD::TRUNCATE_SSAT_U, VT, Expand);
setOperationAction(ISD::TRUNCATE_USAT_U, VT, Expand);

// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
Expand);
Expand Down

0 comments on commit 8d81896

Please sign in to comment.