Skip to content

Commit

Permalink
[SPIR-V] Rework usage of virtual registers' types and classes (llvm#1…
Browse files Browse the repository at this point in the history
…04104)

This PR continues llvm#101732
changes in virtual register processing aimed to improve correctness of
emitted MIR between passes from the perspective of MachineVerifier.
Namely, the following changes are introduced:
* register classes (lib/Target/SPIRV/SPIRVRegisterInfo.td) and
instruction patterns (lib/Target/SPIRV/SPIRVInstrInfo.td) are corrected
and simplified (by removing unnecessary sophisticated options) -- e.g.,
this PR gets rid of duplicating 32/64 bits patterns, removes ANYID
register class and simplifies definition of the rest of register
classes,
* hardcoded LLT scalar types in passes before instruction selection are
corrected -- the goal is to have correct bit width before instruction
selection, and use 64 bits registers for pattern matching in the
instruction selection pass; 32-bit registers remain where they are
described in such terms by SPIR-V specification (like, for example,
creation of virtual registers for scope/mem semantics operands),
* rework virtual register type/class assignment for calls/builtins
lowering,
* a series of minor changes to fix validity of emitted code between
passes:
  - ensure that that bitcast changes the type,
  - fix the pattern for instruction selection for OpExtInst,
  - simplify inline asm operands usage,
  - account for arbitrary integer sizes / update legalizer rules;
* add '-verify-machineinstrs' to existed test cases.

See also llvm#88129 that this PR
may resolve.

This PR fixes a great number of issues reported by MachineVerifier and,
as a result, reduces a number of failed test cases for the mode with
expensive checks set on from ~200 to ~57.
  • Loading branch information
VyacheslavLevytskyy authored and cjdb committed Aug 23, 2024
1 parent 442b6ad commit 5bf3c4d
Show file tree
Hide file tree
Showing 177 changed files with 647 additions and 669 deletions.
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
// Check if we define an ID, and take a type as operand 1.
auto &DefOpInfo = MCDesc.operands()[0];
auto &FirstArgOpInfo = MCDesc.operands()[1];
return DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
}
return false;
Expand Down
254 changes: 82 additions & 172 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Large diffs are not rendered by default.

27 changes: 18 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
}

auto MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
Expand Down Expand Up @@ -403,12 +403,14 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
int i = 0;
for (const auto &Arg : F.args()) {
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
MRI->setRegClass(VRegs[i][0], &SPIRV::iIDRegClass);
Register ArgReg = VRegs[i][0];
MRI->setRegClass(ArgReg, GR->getRegClass(ArgTypeVRegs[i]));
MRI->setType(ArgReg, GR->getRegType(ArgTypeVRegs[i]));
MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
.addDef(VRegs[i][0])
.addDef(ArgReg)
.addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
if (F.isDeclaration())
GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
GR->add(&Arg, &MIRBuilder.getMF(), ArgReg);
i++;
}
// Name the function.
Expand Down Expand Up @@ -532,10 +534,17 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
ArgVRegs.push_back(Arg.Regs[0]);
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
Register ArgReg = Arg.Regs[0];
ArgVRegs.push_back(ArgReg);
SPIRVType *SpvType = GR->getSPIRVTypeForVReg(ArgReg);
if (!SpvType) {
SpvType = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
}
if (!MRI->getRegClassOrNull(ArgReg)) {
MRI->setRegClass(ArgReg, GR->getRegClass(SpvType));
MRI->setType(ArgReg, GR->getRegType(SpvType));
}
}
auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
: SPIRV::InstructionSet::GLSL_std_450;
Expand All @@ -557,7 +566,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
for (const Argument &Arg : CF->args()) {
if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
continue; // Don't handle zero sized types.
Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(Reg, &SPIRV::iIDRegClass);
ToInsert.push_back({Reg});
VRegArgs.push_back(ToInsert.back());
Expand Down
155 changes: 100 additions & 55 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,14 @@ void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
VRegToTypeMap[&MF][VReg] = SpirvType;
}

static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
auto &MRI = MIRBuilder.getMF().getRegInfo();
auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
static Register createTypeVReg(MachineRegisterInfo &MRI) {
auto Res = MRI.createGenericVirtualRegister(LLT::scalar(64));
MRI.setRegClass(Res, &SPIRV::TYPERegClass);
return Res;
}

static Register createTypeVReg(MachineRegisterInfo &MRI) {
auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
MRI.setRegClass(Res, &SPIRV::TYPERegClass);
return Res;
inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
return createTypeVReg(MIRBuilder.getMF().getRegInfo());
}

SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
Expand Down Expand Up @@ -157,26 +154,24 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
return MIB;
}

std::tuple<Register, ConstantInt *, bool>
std::tuple<Register, ConstantInt *, bool, unsigned>
SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
MachineIRBuilder *MIRBuilder,
MachineInstr *I,
const SPIRVInstrInfo *TII) {
const IntegerType *LLVMIntTy;
if (SpvType)
LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
else
LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
assert(SpvType);
const IntegerType *LLVMIntTy =
cast<IntegerType>(getTypeForSPIRVType(SpvType));
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
bool NewInstr = false;
// Find a constant in DT or build a new one.
ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
Res =
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
if (MIRBuilder)
assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
Expand All @@ -185,35 +180,27 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
DT.add(CI, CurMF, Res);
NewInstr = true;
}
return std::make_tuple(Res, CI, NewInstr);
return std::make_tuple(Res, CI, NewInstr, BitWidth);
}

std::tuple<Register, ConstantFP *, bool, unsigned>
SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
MachineIRBuilder *MIRBuilder,
MachineInstr *I,
const SPIRVInstrInfo *TII) {
const Type *LLVMFloatTy;
assert(SpvType);
LLVMContext &Ctx = CurMF->getFunction().getContext();
unsigned BitWidth = 32;
if (SpvType)
LLVMFloatTy = getTypeForSPIRVType(SpvType);
else {
LLVMFloatTy = Type::getFloatTy(Ctx);
if (MIRBuilder)
SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
}
const Type *LLVMFloatTy = getTypeForSPIRVType(SpvType);
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
bool NewInstr = false;
// Find a constant in DT or build a new one.
auto *const CI = ConstantFP::get(Ctx, Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
if (SpvType)
BitWidth = getScalarOrVectorBitWidth(SpvType);
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
Res =
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
if (MIRBuilder)
assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
Expand Down Expand Up @@ -269,7 +256,8 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
ConstantInt *CI;
Register Res;
bool New;
std::tie(Res, CI, New) =
unsigned BitWidth;
std::tie(Res, CI, New, BitWidth) =
getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
// If we have found Res register which is defined by the passed G_CONSTANT
// machine instruction, a new constant instruction should be created.
Expand All @@ -281,7 +269,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(Res)
Expand All @@ -297,19 +285,17 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType,
bool EmitIR) {
assert(SpvType);
auto &MF = MIRBuilder.getMF();
const IntegerType *LLVMIntTy;
if (SpvType)
LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
else
LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
const IntegerType *LLVMIntTy =
cast<IntegerType>(getTypeForSPIRVType(SpvType));
// Find a constant in DT or build a new one.
const auto ConstInt =
ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
Register Res = DT.find(ConstInt, &MF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
LLT LLTy = LLT::scalar(BitWidth);
Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
MF.getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
Expand All @@ -318,18 +304,17 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
if (EmitIR) {
MIRBuilder.buildConstant(Res, *ConstInt);
} else {
if (!SpvType)
SpvType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
Register SpvTypeReg = getSPIRVTypeID(SpvType);
MachineInstrBuilder MIB;
if (Val) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
.addUse(SpvTypeReg);
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
.addUse(SpvTypeReg);
}
const auto &Subtarget = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
Expand All @@ -353,7 +338,8 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
const auto ConstFP = ConstantFP::get(Ctx, Val);
Register Res = DT.find(ConstFP, &MF);
if (!Res.isValid()) {
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
Res = MF.getRegInfo().createGenericVirtualRegister(
LLT::scalar(getScalarOrVectorBitWidth(SpvType)));
MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, MF);
DT.add(ConstFP, &MF, Res);
Expand Down Expand Up @@ -407,7 +393,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(

// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
LLT LLTy = LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
Expand Down Expand Up @@ -509,7 +495,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
}
LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
Expand Down Expand Up @@ -650,7 +636,6 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(

// Set to Reg the same type as ResVReg has.
auto MRI = MIRBuilder.getMRI();
assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
if (Reg != ResVReg) {
LLT RegLLTy =
LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
Expand Down Expand Up @@ -706,8 +691,9 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
bool EmitIR) {
assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
"Invalid array element type");
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
Register NumElementsVReg =
buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
Expand Down Expand Up @@ -1188,14 +1174,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
if (ResVReg.isValid())
return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
ResVReg = createTypeVReg(MIRBuilder);
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
SPIRVType *SpirvTy =
MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
.addDef(ResVReg)
.addUse(getSPIRVTypeID(ElemType))
.addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
.addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, true))
.addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, true))
.addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, true))
.addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, true));
DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
return SpirvTy;
}
Expand Down Expand Up @@ -1386,8 +1373,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
MachineBasicBlock &BB = *I.getParent();
SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, I, TII);
Register Len = getOrCreateConstInt(NumElements, I, SpvTypeInt32, TII);
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addUse(getSPIRVTypeID(BaseType))
Expand Down Expand Up @@ -1436,7 +1423,7 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
Register Res = DT.find(UV, CurMF);
if (Res.isValid())
return Res;
LLT LLTy = LLT::scalar(32);
LLT LLTy = LLT::scalar(64);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
Expand All @@ -1451,3 +1438,61 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
*ST.getRegisterInfo(), *ST.getRegBankInfo());
return Res;
}

const TargetRegisterClass *
SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const {
unsigned Opcode = SpvType->getOpcode();
switch (Opcode) {
case SPIRV::OpTypeFloat:
return &SPIRV::fIDRegClass;
case SPIRV::OpTypePointer:
return &SPIRV::pIDRegClass;
case SPIRV::OpTypeVector: {
SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
if (ElemOpcode == SPIRV::OpTypeFloat)
return &SPIRV::vfIDRegClass;
if (ElemOpcode == SPIRV::OpTypePointer)
return &SPIRV::vpIDRegClass;
return &SPIRV::vIDRegClass;
}
}
return &SPIRV::iIDRegClass;
}

inline unsigned getAS(SPIRVType *SpvType) {
return storageClassToAddressSpace(
static_cast<SPIRV::StorageClass::StorageClass>(
SpvType->getOperand(1).getImm()));
}

LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
unsigned Opcode = SpvType ? SpvType->getOpcode() : 0;
switch (Opcode) {
case SPIRV::OpTypeInt:
case SPIRV::OpTypeFloat:
case SPIRV::OpTypeBool:
return LLT::scalar(getScalarOrVectorBitWidth(SpvType));
case SPIRV::OpTypePointer:
return LLT::pointer(getAS(SpvType), getPointerSize());
case SPIRV::OpTypeVector: {
SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
LLT ET;
switch (ElemType ? ElemType->getOpcode() : 0) {
case SPIRV::OpTypePointer:
ET = LLT::pointer(getAS(ElemType), getPointerSize());
break;
case SPIRV::OpTypeInt:
case SPIRV::OpTypeFloat:
case SPIRV::OpTypeBool:
ET = LLT::scalar(getScalarOrVectorBitWidth(ElemType));
break;
default:
ET = LLT::scalar(64);
}
return LLT::fixed_vector(
static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET);
}
}
return LLT::scalar(64);
}
Loading

0 comments on commit 5bf3c4d

Please sign in to comment.