Skip to content

Commit

Permalink
[LV] Support generating masks for switch terminators. (llvm#99808)
Browse files Browse the repository at this point in the history
Update createEdgeMask to created masks where the terminator in Src is a
switch. We need to handle 2 separate cases:

1. Dst is not the default desintation. Dst is reached if any of the
cases with destination == Dst are taken. Join the conditions for each
case where destination == Dst using a logical OR.
2. Dst is the default destination. Dst is reached if none of the cases
with destination != Dst are taken. Join the conditions for each case
where the destination is != Dst using a logical OR and negate it.

Edge masks are created for every destination of cases and/or 
default when requesting a mask where the source is a switch.

Fixes llvm#48188.

PR: llvm#99808
  • Loading branch information
fhahn authored and bwendling committed Aug 15, 2024
1 parent 7866fad commit 52fd95c
Show file tree
Hide file tree
Showing 8 changed files with 1,349 additions and 52 deletions.
21 changes: 15 additions & 6 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1340,12 +1340,21 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {

// Collect the blocks that need predication.
for (BasicBlock *BB : TheLoop->blocks()) {
// We don't support switch statements inside loops.
if (!isa<BranchInst>(BB->getTerminator())) {
reportVectorizationFailure("Loop contains a switch statement",
"loop contains a switch statement",
"LoopContainsSwitch", ORE, TheLoop,
BB->getTerminator());
// We support only branches and switch statements as terminators inside the
// loop.
if (isa<SwitchInst>(BB->getTerminator())) {
if (TheLoop->isLoopExiting(BB)) {
reportVectorizationFailure("Loop contains an unsupported switch",
"loop contains an unsupported switch",
"LoopContainsUnsupportedSwitch", ORE,
TheLoop, BB->getTerminator());
return false;
}
} else if (!isa<BranchInst>(BB->getTerminator())) {
reportVectorizationFailure("Loop contains an unsupported terminator",
"loop contains an unsupported terminator",
"LoopContainsUnsupportedTerminator", ORE,
TheLoop, BB->getTerminator());
return false;
}

Expand Down
74 changes: 73 additions & 1 deletion llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6453,6 +6453,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
// a predicated block since it will become a fall-through, although we
// may decide in the future to call TTI for all branches.
}
case Instruction::Switch: {
if (VF.isScalar())
return TTI.getCFInstrCost(Instruction::Switch, CostKind);
auto *Switch = cast<SwitchInst>(I);
return Switch->getNumCases() *
TTI.getCmpSelInstrCost(
Instruction::ICmp,
ToVectorTy(Switch->getCondition()->getType(), VF),
ToVectorTy(Type::getInt1Ty(I->getContext()), VF),
CmpInst::ICMP_EQ, CostKind);
}
case Instruction::PHI: {
auto *Phi = cast<PHINode>(I);

Expand Down Expand Up @@ -7841,6 +7852,62 @@ VPRecipeBuilder::mapToVPValues(User::op_range Operands) {
return map_range(Operands, Fn);
}

void VPRecipeBuilder::createSwitchEdgeMasks(SwitchInst *SI) {
BasicBlock *Src = SI->getParent();
assert(!OrigLoop->isLoopExiting(Src) &&
all_of(successors(Src),
[this](BasicBlock *Succ) {
return OrigLoop->getHeader() != Succ;
}) &&
"unsupported switch either exiting loop or continuing to header");
// Create masks where the terminator in Src is a switch. We create mask for
// all edges at the same time. This is more efficient, as we can create and
// collect compares for all cases once.
VPValue *Cond = getVPValueOrAddLiveIn(SI->getCondition(), Plan);
BasicBlock *DefaultDst = SI->getDefaultDest();
MapVector<BasicBlock *, SmallVector<VPValue *>> Dst2Compares;
for (auto &C : SI->cases()) {
BasicBlock *Dst = C.getCaseSuccessor();
assert(!EdgeMaskCache.contains({Src, Dst}) && "Edge masks already created");
// Cases whose destination is the same as default are redundant and can be
// ignored - they will get there anyhow.
if (Dst == DefaultDst)
continue;
auto I = Dst2Compares.insert({Dst, {}});
VPValue *V = getVPValueOrAddLiveIn(C.getCaseValue(), Plan);
I.first->second.push_back(Builder.createICmp(CmpInst::ICMP_EQ, Cond, V));
}

// We need to handle 2 separate cases below for all entries in Dst2Compares,
// which excludes destinations matching the default destination.
VPValue *SrcMask = getBlockInMask(Src);
VPValue *DefaultMask = nullptr;
for (const auto &[Dst, Conds] : Dst2Compares) {
// 1. Dst is not the default destination. Dst is reached if any of the cases
// with destination == Dst are taken. Join the conditions for each case
// whose destination == Dst using an OR.
VPValue *Mask = Conds[0];
for (VPValue *V : ArrayRef<VPValue *>(Conds).drop_front())
Mask = Builder.createOr(Mask, V);
if (SrcMask)
Mask = Builder.createLogicalAnd(SrcMask, Mask);
EdgeMaskCache[{Src, Dst}] = Mask;

// 2. Create the mask for the default destination, which is reached if none
// of the cases with destination != default destination are taken. Join the
// conditions for each case where the destination is != Dst using an OR and
// negate it.
DefaultMask = DefaultMask ? Builder.createOr(DefaultMask, Mask) : Mask;
}

if (DefaultMask) {
DefaultMask = Builder.createNot(DefaultMask);
if (SrcMask)
DefaultMask = Builder.createLogicalAnd(SrcMask, DefaultMask);
}
EdgeMaskCache[{Src, DefaultDst}] = DefaultMask;
}

VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
assert(is_contained(predecessors(Dst), Src) && "Invalid edge");

Expand All @@ -7850,12 +7917,17 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
if (ECEntryIt != EdgeMaskCache.end())
return ECEntryIt->second;

if (auto *SI = dyn_cast<SwitchInst>(Src->getTerminator())) {
createSwitchEdgeMasks(SI);
assert(EdgeMaskCache.contains(Edge) && "Mask for Edge not created?");
return EdgeMaskCache[Edge];
}

VPValue *SrcMask = getBlockInMask(Src);

// The terminator has to be a branch inst!
BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator());
assert(BI && "Unexpected terminator found");

if (!BI->isConditional() || BI->getSuccessor(0) == BI->getSuccessor(1))
return EdgeMaskCache[Edge] = SrcMask;

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class VPRecipeBuilder {
/// Returns the *entry* mask for the block \p BB.
VPValue *getBlockInMask(BasicBlock *BB) const;

/// Create an edge mask for every destination of cases and/or default.
void createSwitchEdgeMasks(SwitchInst *SI);

/// A helper function that computes the predicate of the edge between SRC
/// and DST.
VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst);
Expand Down
Loading

0 comments on commit 52fd95c

Please sign in to comment.