Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Allow intrinsics with aggregate return type to reach GlobalISel #108893

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ SPIR-V backend, along with their descriptions and argument details.
- None
- `[Type, Vararg]`
- Assigns names to types or values, enhancing readability and debuggability of SPIR-V code. Not emitted directly but used for metadata enrichment.
* - `int_spv_value_md`
- None
- `[Metadata]`
- Assigns a set of attributes (such as name and data type) to a value that is the argument of the associated `llvm.fake.use` intrinsic call. The latter is used as a mean to map virtual registers created by IRTranslator to the original value.
* - `int_spv_assign_decoration`
- None
- `[Type, Metadata]`
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ let TargetPrefix = "spv" in {
def int_spv_assign_ptr_type : Intrinsic<[], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
def int_spv_assign_name : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
def int_spv_assign_decoration : Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>;
def int_spv_value_md : Intrinsic<[], [llvm_metadata_ty]>;

def int_spv_track_constant : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty]>;
def int_spv_init_global : Intrinsic<[], [llvm_any_ty, llvm_any_ty]>;
Expand Down
35 changes: 32 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,21 @@ bool isConvergenceIntrinsic(const Instruction *I) {
II->getIntrinsicID() == Intrinsic::experimental_convergence_loop ||
II->getIntrinsicID() == Intrinsic::experimental_convergence_anchor;
}

bool allowEmitFakeUse(const Value *Arg) {
if (const auto *II = dyn_cast<IntrinsicInst>(Arg))
if (Function *F = II->getCalledFunction())
if (F->getName().starts_with("llvm.spv."))
return false;
if (dyn_cast<AtomicCmpXchgInst>(Arg) || dyn_cast<InsertValueInst>(Arg) ||
dyn_cast<UndefValue>(Arg))
return false;
if (const auto *LI = dyn_cast<LoadInst>(Arg))
if (LI->getType()->isAggregateType())
return false;
return true;
}

} // namespace

char SPIRVEmitIntrinsics::ID = 0;
Expand Down Expand Up @@ -283,8 +298,20 @@ static inline Type *reconstructType(SPIRVGlobalRegistry *GR, Value *Op) {
void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
Value *Arg) {
Value *OfType = PoisonValue::get(Ty);
CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type,
{Arg->getType()}, OfType, Arg, {}, B);
CallInst *AssignCI = nullptr;
if (Arg->getType()->isAggregateType() && Ty->isAggregateType() &&
allowEmitFakeUse(Arg)) {
LLVMContext &Ctx = Arg->getContext();
SmallVector<Metadata *, 2> ArgMDs{
MDNode::get(Ctx, ValueAsMetadata::getConstant(OfType)),
MDString::get(Ctx, Arg->getName())};
B.CreateIntrinsic(Intrinsic::spv_value_md, {},
{MetadataAsValue::get(Ctx, MDTuple::get(Ctx, ArgMDs))});
AssignCI = B.CreateIntrinsic(Intrinsic::fake_use, {}, {Arg});
} else {
AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type, {Arg->getType()},
OfType, Arg, {}, B);
}
GR->addAssignPtrTypeInstr(Arg, AssignCI);
}

Expand Down Expand Up @@ -1268,6 +1295,8 @@ Instruction *SPIRVEmitIntrinsics::visitInsertValueInst(InsertValueInst &I) {
}

Instruction *SPIRVEmitIntrinsics::visitExtractValueInst(ExtractValueInst &I) {
if (I.getAggregateOperand()->getType()->isAggregateType())
return &I;
s-perron marked this conversation as resolved.
Show resolved Hide resolved
IRBuilder<> B(I.getParent());
B.SetInsertPoint(&I);
SmallVector<Value *> Args;
Expand Down Expand Up @@ -1533,7 +1562,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
I->setOperand(OpNo, NewOp);
}
}
if (I->hasName()) {
if (I->hasName() && !I->getType()->isAggregateType()) {
reportFatalOnTokenType(I);
setInsertPointAfterDef(B, I);
std::vector<Value *> Args = {I};
Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class SPIRVGlobalRegistry {
// created during substitution of aggregate arguments
// (see `SPIRVPrepareFunctions::removeAggregateTypesFromSignature()`)
DenseMap<Value *, Type *> MutatedAggRet;
// map an instruction to its value's attributes (type, name)
DenseMap<MachineInstr *, std::pair<Type *, std::string>> ValueAttrs;

// Look for an equivalent of the newType in the map. Return the equivalent
// if it's found, otherwise insert newType to the map and return the type.
Expand Down Expand Up @@ -177,6 +179,21 @@ class SPIRVGlobalRegistry {
return It == MutatedAggRet.end() ? nullptr : It->second;
}

// A registry of value's attributes (type, name)
// - Add a record.
void addValueAttrs(MachineInstr *Key, std::pair<Type *, std::string> Val) {
ValueAttrs[Key] = Val;
}
// - Find a record.
bool findValueAttrs(const MachineInstr *Key, Type *&Ty, StringRef &Name) {
auto It = ValueAttrs.find(Key);
if (It == ValueAttrs.end())
return false;
Ty = It->second.first;
Name = It->second.second;
return true;
}

// Deduced element types of untyped pointers and composites:
// - Add a record to the map of deduced element types.
void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,8 @@ def OpMatrixTimesMatrix: BinOp<"OpMatrixTimesMatrix", 146>;
def OpOuterProduct: BinOp<"OpOuterProduct", 147>;
def OpDot: BinOp<"OpDot", 148>;

def OpIAddCarry: BinOpTyped<"OpIAddCarry", 149, iID, addc>;
def OpISubBorrow: BinOpTyped<"OpISubBorrow", 150, iID, subc>;
defm OpIAddCarry: BinOpTypedGen<"OpIAddCarry", 149, addc, 0, 1>;
defm OpISubBorrow: BinOpTypedGen<"OpISubBorrow", 150, subc, 0, 1>;
def OpUMulExtended: BinOp<"OpUMulExtended", 151>;
def OpSMulExtended: BinOp<"OpSMulExtended", 152>;

Expand Down
99 changes: 98 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectFloatDot(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectOverflowArith(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, unsigned Opcode) const;

bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

Expand Down Expand Up @@ -409,11 +412,22 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
return false;
}

static bool mayApplyGenericSelection(unsigned Opcode) {
switch (Opcode) {
case TargetOpcode::G_CONSTANT:
return false;
case TargetOpcode::G_SADDO:
case TargetOpcode::G_SSUBO:
return true;
}
return isTypeFoldingSupported(Opcode);
}

bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
const unsigned Opcode = I.getOpcode();
if (isTypeFoldingSupported(Opcode) && Opcode != TargetOpcode::G_CONSTANT)
if (mayApplyGenericSelection(Opcode))
return selectImpl(I, *CoverageInfo);
switch (Opcode) {
case TargetOpcode::G_CONSTANT:
Expand Down Expand Up @@ -590,6 +604,21 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_USUBSAT:
return selectExtInst(ResVReg, ResType, I, CL::u_sub_sat);

case TargetOpcode::G_UADDO:
return selectOverflowArith(ResVReg, ResType, I,
ResType->getOpcode() == SPIRV::OpTypeVector
? SPIRV::OpIAddCarryV
: SPIRV::OpIAddCarryS);
case TargetOpcode::G_USUBO:
return selectOverflowArith(ResVReg, ResType, I,
ResType->getOpcode() == SPIRV::OpTypeVector
? SPIRV::OpISubBorrowV
: SPIRV::OpISubBorrowS);
case TargetOpcode::G_UMULO:
return selectOverflowArith(ResVReg, ResType, I, SPIRV::OpUMulExtended);
case TargetOpcode::G_SMULO:
return selectOverflowArith(ResVReg, ResType, I, SPIRV::OpSMulExtended);

case TargetOpcode::G_SEXT:
return selectExt(ResVReg, ResType, I, true);
case TargetOpcode::G_ANYEXT:
Expand Down Expand Up @@ -1101,6 +1130,71 @@ bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectOverflowArith(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
unsigned Opcode) const {
Type *ResTy = nullptr;
StringRef ResName;
if (!GR.findValueAttrs(&I, ResTy, ResName))
report_fatal_error(
"Not enough info to select the arithmetic with overflow instruction");
if (!ResTy || !ResTy->isStructTy())
report_fatal_error("Expect struct type result for the arithmetic "
"with overflow instruction");
// "Result Type must be from OpTypeStruct. The struct must have two members,
// and the two members must be the same type."
Type *ResElemTy = cast<StructType>(ResTy)->getElementType(0);
ResTy = StructType::create(SmallVector<Type *, 2>{ResElemTy, ResElemTy});
// Build SPIR-V types and constant(s) if needed.
MachineIRBuilder MIRBuilder(I);
SPIRVType *StructType = GR.getOrCreateSPIRVType(
ResTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false);
assert(I.getNumDefs() > 1 && "Not enought operands");
SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII);
unsigned N = GR.getScalarOrVectorComponentCount(ResType);
if (N > 1)
BoolType = GR.getOrCreateSPIRVVectorType(BoolType, N, I, TII);
Register BoolTypeReg = GR.getSPIRVTypeID(BoolType);
Register ZeroReg = buildZerosVal(ResType, I);
// A new virtual register to store the result struct.
Register StructVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(StructVReg, &SPIRV::IDRegClass);
// Build the result name if needed.
if (ResName.size() > 0)
buildOpName(StructVReg, ResName, MIRBuilder);
// Build the arithmetic with overflow instruction.
MachineBasicBlock &BB = *I.getParent();
auto MIB =
BuildMI(BB, MIRBuilder.getInsertPt(), I.getDebugLoc(), TII.get(Opcode))
.addDef(StructVReg)
.addUse(GR.getSPIRVTypeID(StructType));
for (unsigned i = I.getNumDefs(); i < I.getNumOperands(); ++i)
MIB.addUse(I.getOperand(i).getReg());
bool Status = MIB.constrainAllUses(TII, TRI, RBI);
// Build instructions to extract fields of the instruction's result.
// A new virtual register to store the higher part of the result struct.
Register HigherVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(HigherVReg, &SPIRV::iIDRegClass);
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
auto MIB =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(i == 1 ? HigherVReg : I.getOperand(i).getReg())
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(StructVReg)
.addImm(i);
Status &= MIB.constrainAllUses(TII, TRI, RBI);
}
// Build boolean value from the higher part.
Status &= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpINotEqual))
.addDef(I.getOperand(1).getReg())
.addUse(BoolTypeReg)
.addUse(HigherVReg)
.addUse(ZeroReg)
.constrainAllUses(TII, TRI, RBI);
return Status;
}

bool SPIRVInstructionSelector::selectAtomicCmpXchg(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2505,6 +2599,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
}
case Intrinsic::spv_step:
return selectStep(ResVReg, ResType, I);
case Intrinsic::spv_value_md:
// ignore the intrinsic
break;
default: {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
// TODO: add proper legalization rules.
getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();

getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
getActionDefinitionsBuilder(
{G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
.alwaysLegal();

// FP conversions.
Expand Down
27 changes: 26 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,13 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
.addUse(NewReg)
.addUse(GR->getSPIRVTypeID(SpvType))
.setMIFlags(Flags);
Def->getOperand(0).setReg(NewReg);
for (unsigned I = 0, E = Def->getNumDefs(); I != E; ++I) {
MachineOperand &MO = Def->getOperand(I);
if (MO.getReg() == Reg) {
MO.setReg(NewReg);
break;
}
}
return NewReg;
}

Expand Down Expand Up @@ -462,6 +468,25 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Def->getOpcode() != SPIRV::ASSIGN_TYPE)
insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MF.getRegInfo());
ToErase.push_back(&MI);
} else if (MIOp == TargetOpcode::FAKE_USE && MI.getNumOperands() > 0) {
MachineInstr *MdMI = MI.getPrevNode();
VyacheslavLevytskyy marked this conversation as resolved.
Show resolved Hide resolved
if (MdMI && isSpvIntrinsic(*MdMI, Intrinsic::spv_value_md)) {
// It's an internal service info from before IRTranslator passes.
MachineInstr *Def = getVRegDef(MRI, MI.getOperand(0).getReg());
for (unsigned I = 1, E = MI.getNumOperands(); I != E && Def; ++I)
if (getVRegDef(MRI, MI.getOperand(I).getReg()) != Def)
Def = nullptr;
if (Def) {
const MDNode *MD = MdMI->getOperand(1).getMetadata();
StringRef ValueName =
cast<MDString>(MD->getOperand(1))->getString();
const MDNode *TypeMD = cast<MDNode>(MD->getOperand(0));
Type *ValueTy = getMDOperandAsType(TypeMD, 0);
GR->addValueAttrs(Def, std::make_pair(ValueTy, ValueName.str()));
}
ToErase.push_back(MdMI);
}
ToErase.push_back(&MI);
} else if (MIOp == TargetOpcode::G_CONSTANT ||
MIOp == TargetOpcode::G_FCONSTANT ||
MIOp == TargetOpcode::G_BUILD_VECTOR) {
Expand Down
48 changes: 5 additions & 43 deletions llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,30 +342,6 @@ static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
FSHIntrinsic->setCalledFunction(FSHFunc);
}

static void buildUMulWithOverflowFunc(Function *UMulFunc) {
// The function body is already created.
if (!UMulFunc->empty())
return;

BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
"entry", UMulFunc);
IRBuilder<> IRB(EntryBB);
// Build the actual unsigned multiplication logic with the overflow
// indication. Do unsigned multiplication Mul = A * B. Then check
// if unsigned division Div = Mul / A is not equal to B. If so,
// then overflow has happened.
Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);

// umul.with.overflow intrinsic return a structure, where the first element
// is the multiplication result, and the second is an overflow bit.
Type *StructTy = UMulFunc->getReturnType();
Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
IRB.CreateRet(Res);
}

static void lowerExpectAssume(IntrinsicInst *II) {
// If we cannot use the SPV_KHR_expect_assume extension, then we need to
// ignore the intrinsic and move on. It should be removed later on by LLVM.
Expand Down Expand Up @@ -407,20 +383,6 @@ static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
return true;
}

static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
// Get a separate function - otherwise, we'd have to rework the CFG of the
// current one. Then simply replace the intrinsic uses with a call to the new
// function.
Module *M = UMulIntrinsic->getModule();
FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
Type *FSHLRetTy = UMulFuncTy->getReturnType();
const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
Function *UMulFunc =
getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
buildUMulWithOverflowFunc(UMulFunc);
UMulIntrinsic->setCalledFunction(UMulFunc);
}

// Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
// or calls to proper generated functions. Returns True if F was modified.
bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
Expand All @@ -444,10 +406,6 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
lowerFunnelShifts(II);
Changed = true;
break;
case Intrinsic::umul_with_overflow:
lowerUMulWithOverflow(II);
Changed = true;
break;
case Intrinsic::assume:
case Intrinsic::expect: {
const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
Expand Down Expand Up @@ -478,9 +436,13 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
// noted in 'spv.cloned_funcs' metadata for later restoration.
Function *
SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
bool IsRetAggr = F->getReturnType()->isAggregateType();
// Allow intrinsics with aggregate return type to reach GlobalISel
if (F->isIntrinsic() && IsRetAggr)
return F;

IRBuilder<> B(F->getContext());

bool IsRetAggr = F->getReturnType()->isAggregateType();
bool HasAggrArg =
std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
return Arg.getType()->isAggregateType();
Expand Down
Loading