Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* Reset third_party/llvm-project:
4f15267d3dd797a15901fe9352f0d5fa121b9095 (2023-02-15 16:52:25 +0100):
[libc++][NFC] Replace _LIBCPP_STD_VER > x with _LIBCPP_STD_VER >= x
* Updated to tensorflow/tensorflow@75eaca4
* Updated to tensorflow/mlir-hlo@a913e03
* Cherry-picked MLIR bug fix llvm/llvm-project@3cf7f22
* Cherry-picked MLIR bug fix llvm/llvm-project@e44f405
* Used `llvm/TargetParser/Host.h` to replace `llvm/Support/Host.h`
* Used `llvm::bit_floor` to replace `llvm::PowerOf2Floor`
* Updated GPU memory space handling in converting to LLVM
* Run `python compiler/src/iree/compiler/API2/generate_exports.py`
* Fixed bufferization issue in transform dialect path

---------

Co-authored-by: Hanhan Wang <hanchung@google.com>
Co-authored-by: Thomas Raoux <thomasraoux@google.com>
Co-authored-by: Matthias Springer <springerm@google.com>
  • Loading branch information
4 people authored Feb 16, 2023
1 parent a04c262 commit 1290401
Show file tree
Hide file tree
Showing 30 changed files with 153 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/CrashRecoveryContext.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/PluginLoader.h"
#include "llvm/Support/Process.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"

using namespace lld;
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/API2/api_exports.c
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ extern void mlirF64TypeGet();
extern void mlirFlatSymbolRefAttrGet();
extern void mlirFlatSymbolRefAttrGetValue();
extern void mlirFloat8E4M3FNTypeGet();
extern void mlirFloat8E4M3FNUZTypeGet();
extern void mlirFloat8E5M2FNUZTypeGet();
extern void mlirFloat8E5M2TypeGet();
extern void mlirFloatAttrDoubleGet();
extern void mlirFloatAttrDoubleGetChecked();
Expand Down Expand Up @@ -533,7 +535,9 @@ extern void mlirTypeIsAF16();
extern void mlirTypeIsAF32();
extern void mlirTypeIsAF64();
extern void mlirTypeIsAFloat8E4M3FN();
extern void mlirTypeIsAFloat8E4M3FNUZ();
extern void mlirTypeIsAFloat8E5M2();
extern void mlirTypeIsAFloat8E5M2FNUZ();
extern void mlirTypeIsAFunction();
extern void mlirTypeIsAIndex();
extern void mlirTypeIsAInteger();
Expand Down Expand Up @@ -880,6 +884,8 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&mlirFlatSymbolRefAttrGet;
x += (uintptr_t)&mlirFlatSymbolRefAttrGetValue;
x += (uintptr_t)&mlirFloat8E4M3FNTypeGet;
x += (uintptr_t)&mlirFloat8E4M3FNUZTypeGet;
x += (uintptr_t)&mlirFloat8E5M2FNUZTypeGet;
x += (uintptr_t)&mlirFloat8E5M2TypeGet;
x += (uintptr_t)&mlirFloatAttrDoubleGet;
x += (uintptr_t)&mlirFloatAttrDoubleGetChecked;
Expand Down Expand Up @@ -1111,7 +1117,9 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&mlirTypeIsAF32;
x += (uintptr_t)&mlirTypeIsAF64;
x += (uintptr_t)&mlirTypeIsAFloat8E4M3FN;
x += (uintptr_t)&mlirTypeIsAFloat8E4M3FNUZ;
x += (uintptr_t)&mlirTypeIsAFloat8E5M2;
x += (uintptr_t)&mlirTypeIsAFloat8E5M2FNUZ;
x += (uintptr_t)&mlirTypeIsAFunction;
x += (uintptr_t)&mlirTypeIsAIndex;
x += (uintptr_t)&mlirTypeIsAInteger;
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API2/api_exports.def
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ EXPORTS
mlirFlatSymbolRefAttrGet
mlirFlatSymbolRefAttrGetValue
mlirFloat8E4M3FNTypeGet
mlirFloat8E4M3FNUZTypeGet
mlirFloat8E5M2FNUZTypeGet
mlirFloat8E5M2TypeGet
mlirFloatAttrDoubleGet
mlirFloatAttrDoubleGetChecked
Expand Down Expand Up @@ -525,7 +527,9 @@ EXPORTS
mlirTypeIsAF32
mlirTypeIsAF64
mlirTypeIsAFloat8E4M3FN
mlirTypeIsAFloat8E4M3FNUZ
mlirTypeIsAFloat8E5M2
mlirTypeIsAFloat8E5M2FNUZ
mlirTypeIsAFunction
mlirTypeIsAIndex
mlirTypeIsAInteger
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API2/api_exports.ld
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ VER_0 {
mlirFlatSymbolRefAttrGet;
mlirFlatSymbolRefAttrGetValue;
mlirFloat8E4M3FNTypeGet;
mlirFloat8E4M3FNUZTypeGet;
mlirFloat8E5M2FNUZTypeGet;
mlirFloat8E5M2TypeGet;
mlirFloatAttrDoubleGet;
mlirFloatAttrDoubleGetChecked;
Expand Down Expand Up @@ -526,7 +528,9 @@ VER_0 {
mlirTypeIsAF32;
mlirTypeIsAF64;
mlirTypeIsAFloat8E4M3FN;
mlirTypeIsAFloat8E4M3FNUZ;
mlirTypeIsAFloat8E5M2;
mlirTypeIsAFloat8E5M2FNUZ;
mlirTypeIsAFunction;
mlirTypeIsAIndex;
mlirTypeIsAInteger;
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/API2/api_exports.macos.lst
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ _mlirF64TypeGet
_mlirFlatSymbolRefAttrGet
_mlirFlatSymbolRefAttrGetValue
_mlirFloat8E4M3FNTypeGet
_mlirFloat8E4M3FNUZTypeGet
_mlirFloat8E5M2FNUZTypeGet
_mlirFloat8E5M2TypeGet
_mlirFloatAttrDoubleGet
_mlirFloatAttrDoubleGetChecked
Expand Down Expand Up @@ -524,7 +526,9 @@ _mlirTypeIsAF16
_mlirTypeIsAF32
_mlirTypeIsAF64
_mlirTypeIsAFloat8E4M3FN
_mlirTypeIsAFloat8E4M3FNUZ
_mlirTypeIsAFloat8E5M2
_mlirTypeIsAFloat8E5M2FNUZ
_mlirTypeIsAFunction
_mlirTypeIsAIndex
_mlirTypeIsAInteger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ include "mlir/IR/PatternBase.td"

// Canonicalize unnecessary tensor_load when the load is used just for
// an extract
def : Pat<(Tensor_ExtractOp (Bufferization_ToTensorOp $value), $indices),
def : Pat<(Tensor_ExtractOp(Bufferization_ToTensorOp $value, $restrict,
$writable),
$indices),
(LoadOp $value, $indices, ConstBoolAttrFalse)>;

#endif // IREE_COMPILER_CODEGEN_COMMON_FOLDTENSOREXTRACTOP
Original file line number Diff line number Diff line change
Expand Up @@ -1054,12 +1054,7 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
options.memCpyFn = memCpyFn;
options.testAnalysisOnly = getTestAnalysisOnly();
options.printConflicts = getPrintConflicts();
WalkResult res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
if (failed(runIREEOneShotBufferize(moduleOp, options)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (res.wasInterrupted())
if (failed(runIREEOneShotBufferize(state.getTopLevel(), options)))
return DiagnosedSilenceableFailure::definiteFailure();

// Early exit if test_analysis_only is set.
Expand All @@ -1071,7 +1066,7 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
// 3. Post-bufferization passes are fine.
PassManager pm(getContext());
addIREEPostBufferizationPasses(pm);
res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
WalkResult res = state.getTopLevel()->walk([&](ModuleOp moduleOp) {
if (failed(pm.run(moduleOp))) {
getOperation()->emitError()
<< "failed to post-bufferization passes on module:\n"
Expand Down
4 changes: 1 addition & 3 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,7 @@ MemRefDescriptor HALDispatchABI::loadBinding(Operation *forOp, int64_t ordinal,
// Cast to the desired memref element type.
auto elementType = typeConverter->convertType(memRefType.getElementType());
Value typedPtrValue = builder.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(elementType, memRefType.getMemorySpaceAsInt()),
basePtrValue);
loc, LLVM::LLVMPointerType::get(elementType), basePtrValue);

// Construct the MemRefDescriptor type based on the information we have.
// NOTE: we could use the binding length to clamp this/check that the
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,8 @@ static SmallVector<int64_t> getDefaultDistributedLoopTileSizes(
}
// Fallback to power of 2 if there's no hint or can't find the ideal size.
if (vectorSize <= 1 || candidateTileSize == 1) {
candidateTileSize =
std::max<int64_t>(llvm::PowerOf2Floor(targetSize), minTileSizes[i]);
candidateTileSize = std::max<int64_t>(
llvm::bit_floor<uint64_t>(targetSize), minTileSizes[i]);
}

// Limit the workload per workgroup to the default being the max to keep the
Expand Down
28 changes: 23 additions & 5 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,12 @@ struct ConvertSharedMemAllocOp : public OpRewritePattern<memref::AllocOp> {

LogicalResult matchAndRewrite(memref::AllocOp allocOp,
PatternRewriter &rewriter) const override {
if (allocOp.getType().getMemorySpaceAsInt() != 3) return failure();
auto addressSpace = allocOp.getType()
.getMemorySpace()
.dyn_cast_or_null<gpu::AddressSpaceAttr>();
if (!addressSpace ||
addressSpace.getValue() != gpu::GPUDialect::getWorkgroupAddressSpace())
return failure();
ArrayRef<int64_t> shape = allocOp.getType().getShape();
if (llvm::any_of(shape,
[](int64_t dim) { return dim == ShapedType::kDynamic; })) {
Expand Down Expand Up @@ -266,8 +271,7 @@ class ConvertFunc : public ConvertToLLVMPattern {
funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
auto memrefType = subspanOp.getType().cast<MemRefType>();
Type elType = memrefType.getElementType();
auto llvmType =
LLVM::LLVMPointerType::get(elType, memrefType.getMemorySpaceAsInt());
auto llvmType = LLVM::LLVMPointerType::get(elType);
llvmInputTypes[argMapping[SetBinding(subspanOp.getSet(),
subspanOp.getBinding())]] = llvmType;
});
Expand Down Expand Up @@ -388,8 +392,7 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
loc, llvmBufferBasei8Ptr.getType(), llvmBufferBasei8Ptr,
adaptor.getByteOffset());
}
auto llvmPtrType = LLVM::LLVMPointerType::get(
memrefType.getElementType(), memrefType.getMemorySpaceAsInt());
auto llvmPtrType = LLVM::LLVMPointerType::get(memrefType.getElementType());
Value llvmBufferBasePtr =
rewriter.create<LLVM::BitcastOp>(loc, llvmPtrType, llvmBufferBasei8Ptr);
if (memrefType.hasStaticShape()) {
Expand Down Expand Up @@ -519,5 +522,20 @@ std::unique_ptr<OperationPass<ModuleOp>> createTestLLVMGPULegalizePass() {
return std::make_unique<TestLLVMGPULegalizeOpPass>();
}

/// Wraps a raw numeric address-space value into a 64-bit IntegerAttr so it
/// can be attached to a memref type as its memory-space attribute.
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
  auto i64Type = IntegerType::get(ctx, 64);
  return IntegerAttr::get(i64Type, space);
}

/// Registers a type-attribute conversion on `typeConverter` that rewrites
/// gpu::AddressSpaceAttr memory spaces on memref types into numeric (i64)
/// address spaces, using `mapping` to choose the target integer value.
void populateGpuMemorySpaceAttributeConversions(
    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
  typeConverter.addTypeAttributeConversion(
      [mapping](BaseMemRefType type, gpu::AddressSpaceAttr memorySpaceAttr) {
        // Map the symbolic GPU address space to its numeric equivalent.
        const unsigned mappedSpace = mapping(memorySpaceAttr.getValue());
        return wrapNumericMemorySpace(memorySpaceAttr.getContext(),
                                      mappedSpace);
      });
}

} // namespace iree_compiler
} // namespace mlir
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"

namespace mlir {
namespace gpu {
enum class AddressSpace : uint32_t;
}
namespace iree_compiler {

/// Verifies compatibility of the module for application of the LLVM
Expand All @@ -33,6 +36,11 @@ void populateConvertSharedMemoryAllocOps(RewritePatternSet &patterns);

void ConvertToDynamicSharedMemory(ModuleOp moduleOp);

using MemorySpaceMapping =
std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
void populateGpuMemorySpaceAttributeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping);

} // namespace iree_compiler
} // namespace mlir

Expand Down
32 changes: 25 additions & 7 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,16 @@ struct DropSharedMemoryDeallocOp : public OpRewritePattern<memref::DeallocOp> {

LogicalResult matchAndRewrite(memref::DeallocOp op,
PatternRewriter &rewriter) const override {
unsigned addressSpace =
op.getMemref().getType().cast<MemRefType>().getMemorySpaceAsInt();
if (addressSpace == NVVM::NVVMMemorySpace::kSharedMemorySpace) {
rewriter.eraseOp(op);
return success();
}
return failure();
auto addressSpace = op.getMemref()
.getType()
.cast<MemRefType>()
.getMemorySpace()
.dyn_cast_or_null<gpu::AddressSpaceAttr>();
if (!addressSpace ||
addressSpace.getValue() != gpu::GPUDialect::getWorkgroupAddressSpace())
return failure();
rewriter.eraseOp(op);
return success();
}
};

Expand All @@ -72,6 +75,21 @@ struct ConvertToNVVMPass : public ConvertToNVVMBase<ConvertToNVVMPass> {
LowerToLLVMOptions options(m.getContext(), DataLayout(m));
options.overrideIndexBitwidth(64);
LLVMTypeConverter converter(m.getContext(), options);
populateGpuMemorySpaceAttributeConversions(
converter, [](gpu::AddressSpace space) -> unsigned {
switch (space) {
case gpu::AddressSpace::Global:
return static_cast<unsigned>(
NVVM::NVVMMemorySpace::kGlobalMemorySpace);
case gpu::AddressSpace::Workgroup:
return static_cast<unsigned>(
NVVM::NVVMMemorySpace::kSharedMemorySpace);
case gpu::AddressSpace::Private:
return 0;
}
llvm_unreachable("unknown address space enum value");
return 0;
});
// Lowering for MMAMatrixType.
converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
return convertMMAToLLVMType(type);
Expand Down
13 changes: 13 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase<ConvertToROCDLPass> {
LowerToLLVMOptions options(m.getContext(), DataLayout(m));
options.overrideIndexBitwidth(64);
LLVMTypeConverter converter(m.getContext(), options);
populateGpuMemorySpaceAttributeConversions(
converter, [](gpu::AddressSpace space) {
switch (space) {
case gpu::AddressSpace::Global:
return 1;
case gpu::AddressSpace::Workgroup:
return 3;
case gpu::AddressSpace::Private:
return 5;
}
llvm_unreachable("unknown address space enum value");
return 0;
});
// Apply in-dialect lowering first. In-dialect lowering will replace ops
// which need to be lowered further, which is not supported by a single
// conversion pass.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ struct LLVMGPUVectorToGPUPass
return signalPassFailure();
}

IRRewriter rewriter(&getContext());
if (llvmgpuUseMMASync) {
if (failed(convertVectorToNVVMCompatibleMMASync(funcOp))) {
if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) {
return signalPassFailure();
}
// Using TF32 for Float.
Expand All @@ -81,7 +82,9 @@ struct LLVMGPUVectorToGPUPass
return signalPassFailure();
}
} else {
convertVectorToMMAOps(funcOp);
if (failed(convertVectorToMMAOps(rewriter, funcOp))) {
return signalPassFailure();
}
}
createAsyncGroups(funcOp, llvmgpuUseMMASync);

Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ static void addLowerToLLVMGPUPasses(OpPassManager &pm, bool useROCM) {
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
pm.addPass(memref::createExpandStridedMetadataPass());
pm.addPass(createLowerAffinePass());
pm.addPass(createGPULowerMemorySpaceAttributesPass());
// Strip out the debug info for the kernel as CUDA driver doesn't digest PTX
// debug info well.
pm.addPass(createStripDebugInfoPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,14 @@ transform_dialect::VectorToMMAConversionOp::applyToOne(
mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
populatePrepareVectorToMMAPatterns(patterns, /*llvmgpuUseMMASync=*/false);
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) {
target->emitOpError("vector to mma preparation patterns failed to apply");
return emitDefaultDefiniteFailure(target);
}
IRRewriter rewriter(getContext());
if (failed(convertVectorToMMAOps(rewriter, target))) {
target->emitOpError("vector to mma patterns failed to apply");
return emitDefaultDefiniteFailure(target);
}
convertVectorToMMAOps(target);

results.push_back(target);
return DiagnosedSilenceableFailure::success();
Expand Down
Loading

0 comments on commit 1290401

Please sign in to comment.