Skip to content

Commit

Permalink
[mlir][mlprogram] Add mlprogram-pipeline-globals optimization pass
Browse files Browse the repository at this point in the history
Added pass optimizes MLProgram global operations by reducing to only
the minimal load/store operations for global tensors. This avoids
unnecessary global operations throughout a program and potentially
improves operation fusion.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D159228
  • Loading branch information
rsuderman committed Sep 19, 2023
1 parent b2ef297 commit cbd4750
Show file tree
Hide file tree
Showing 10 changed files with 582 additions and 4 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Descend into the MLProgram dialect's IR definitions and, now, its
# transform-pass declarations.
add_subdirectory(IR)
add_subdirectory(Transforms)
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Generate the MLProgram pass declarations (Passes.h.inc) from Passes.td.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name MLProgram)
add_public_tablegen_target(MLIRMLProgramPassIncGen)
# Ensure the generated header exists before anything depending on mlir-headers
# compiles.
add_dependencies(mlir-headers MLIRMLProgramPassIncGen)

# Generate user-facing documentation for the MLProgram passes.
add_mlir_doc(Passes MLProgramPasses ./ -gen-pass-doc)
35 changes: 35 additions & 0 deletions mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace ml_program {

// Tablegen-generated pass declarations (base classes, options) from Passes.td.
#define GEN_PASS_DECL
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

/// Creates a module pass that reduces `ml_program` global tensor loads and
/// stores to the minimal set required, forwarding known values where possible.
std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"

} // namespace ml_program
} // namespace mlir

#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

// Module-scoped pass removing redundant ml_program global loads/stores;
// implemented in lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp.
def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
  let summary = "Optimize `ml_program` global operations for read and store";
  let description = [{
    `ml_program`'s load and store operations can be optimized for
    write-write or write-read sets of operations. This allows known
    tensors to not be re-read when the value is already known in IR.

    The pass is designed to handle both nested regions and function calls
    safely.
  }];
  let constructor = "mlir::ml_program::createMLProgramPipelineGlobalsPass()";
}

#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
Expand Down Expand Up @@ -72,6 +73,7 @@ inline void registerAllPasses() {
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
ml_program::registerMLProgramPasses();
registerSCFPasses();
registerShapePasses();
spirv::registerSPIRVPasses();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/MLProgram/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Build the dialect IR library and the new transforms library.
add_subdirectory(IR)
add_subdirectory(Transforms)
20 changes: 16 additions & 4 deletions mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,14 @@ LogicalResult GlobalOp::verify() {
//===----------------------------------------------------------------------===//

GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  // Resolve the referenced global by walking outward through enclosing
  // symbol tables, starting from this op's parent. Returns a null op if no
  // enclosing scope defines the symbol.
  Operation *scope = getOperation()->getParentOp();
  while (scope) {
    if (auto global = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
            scope, getGlobalAttr()))
      return global;
    scope = scope->getParentOp();
  }
  return {};
}

LogicalResult
Expand Down Expand Up @@ -253,8 +259,14 @@ GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//

GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
  // Resolve the referenced global by walking outward through enclosing
  // symbol tables. The loop advance lives in the for-statement header (the
  // original advanced `parent` at the tail of the body, which reads as a
  // while-loop and would spin forever if a `continue` were ever added) and
  // now matches the sibling GlobalLoadOp::getGlobalOp.
  for (auto parent = getOperation()->getParentOp(); parent;
       parent = parent->getParentOp()) {
    if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
            parent, getGlobalAttr())) {
      return nearest;
    }
  }
  // No enclosing scope defines the symbol.
  return {};
}

LogicalResult
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Library holding the MLProgram dialect transformation passes.
add_mlir_dialect_library(MLIRMLProgramTransforms
  PipelineGlobalOps.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram/Transforms

  # The tablegen'd pass declarations must be generated before compiling.
  DEPENDS
  MLIRMLProgramPassIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRMLProgramDialect
  MLIRPass
)
234 changes: 234 additions & 0 deletions mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
//===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MLProgram/Transforms/Passes.h"

#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace ml_program {
#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"

namespace {

// Pass implementation: prunes redundant ml_program global loads/stores on a
// module, using an inter-procedural map of which globals each callee touches.
class MLProgramPipelineGlobals
    : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
public:
  void runOnOperation() override;

private:
  // Populates loadSymbolsMap/storeSymbolsMap for every symbol-callable in
  // `op`, including globals touched transitively through nested calls.
  // Fails when a callee is referenced by Value (indirect call), since its
  // global accesses cannot be traced.
  LogicalResult buildGlobalMap(ModuleOp op);

  // Walks `block` (recursing into sub-regions), deleting loads whose value is
  // already known and stores that are overwritten before being observed.
  // Records in `symbolLoad`/`symbolStore` the globals the block touched.
  void ProcessBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
                    llvm::DenseSet<SymbolRefAttr> &symbolStore);

  // Per-callable sets of globals loaded/stored, directly or via callees.
  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
};

// Traverses upwards through parent operations, returning the first operation
// the symbol table resolves `symbol` to, or nullptr when no enclosing scope
// defines it.
static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
  Operation *scope = baseOp;
  while (scope) {
    if (Operation *resolved =
            SymbolTable::lookupNearestSymbolFrom(scope, symbol))
      return resolved;
    scope = scope->getParentOp();
  }
  return nullptr;
}

// Builds map from a symbol to MLProgram global symbols loaded or stored
// during processing.
LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
  // Symbol -> the callable operation it resolves to, for every symbol-based
  // call site found in the module.
  llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
  auto res = module->walk([&](Operation *op) {
    if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
      auto callable = caller.getCallableForCallee();
      // For now we do not know how to handle Value based tracing, so fail.
      if (mlir::isa<Value>(callable)) {
        return WalkResult::interrupt();
      }

      auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
      auto func = getFromSymbol(op, symbol);
      callableMap[symbol] = func;
    }
    return WalkResult::advance();
  });

  // An interrupted walk means an untraceable (indirect) call was found; the
  // pass cannot reason safely about globals in that case.
  if (res.wasInterrupted()) {
    return failure();
  }

  // First grab all symbols loaded or stored by each function. This
  // will not handle calls initially.
  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
  for (auto callable : callableMap) {
    llvm::DenseSet<SymbolRefAttr> loadSymbols;
    llvm::DenseSet<SymbolRefAttr> storeSymbols;

    callable.getSecond()->walk(
        [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });

    callable.getSecond()->walk(
        [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });

    opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
    opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
  }

  // For each callable function we find each global loaded/stored within the
  // function or a nested called function. This includes recursion checking to
  // avoid infinitely recursing.
  for (auto callable : callableMap) {
    SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
    // Worklist-driven transitive closure over the call graph; `visited`
    // guards against cycles from (mutually) recursive functions.
    llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
    llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
    llvm::DenseSet<SymbolRefAttr> loadSymbols;
    llvm::DenseSet<SymbolRefAttr> storeSymbols;

    // Indexed loop because `work` grows while iterating.
    for (size_t i = 0; i < work.size(); ++i) {
      callableMap[work[i]]->walk([&](CallOpInterface call) {
        auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
        if (!visited.contains(symbol)) {
          visited.insert(symbol);
          work.push_back(symbol);
        }
      });

      // Accumulate the direct load/store sets of every reachable callee.
      for (auto load : opLoadSymbols[work[i]])
        loadSymbols.insert(load);

      for (auto store : opStoreSymbols[work[i]])
        storeSymbols.insert(store);
    }

    loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
    storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
  }

  return success();
}

// Process each operation in the block deleting unneeded loads / stores,
// recursing on subblocks and checking function calls.
void MLProgramPipelineGlobals::ProcessBlock(
    Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
    llvm::DenseSet<SymbolRefAttr> &symbolStore) {

  // Last known SSA value per global (set by a load or a store), and the most
  // recent store per global (a deletion candidate if superseded unobserved).
  llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
  llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
  // Deletions are deferred so iteration over `block` is not invalidated.
  llvm::SmallVector<Operation *> toDelete;
  for (auto &op : block) {
    // If this is a global load, remap to a previous value if known
    // and delete this load. Remember that this value is the currently
    // known load.
    if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
      auto ref = load.getGlobal();
      symbolLoad.insert(ref);
      if (previousLoads.contains(ref)) {
        toDelete.push_back(&op);
        load.getResult().replaceAllUsesWith(previousLoads[ref]);
      } else {
        previousLoads[ref] = load.getResult();
      }
      continue;
    }

    // Delete a previous store if it exists and is not needed, update
    // the most recent known value for this global ref.
    if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
      auto ref = store.getGlobal();
      symbolStore.insert(ref);
      if (previousStores.contains(ref)) {
        toDelete.push_back(previousStores.find(ref)->getSecond());
      }

      // The stored value is also what a subsequent load would observe.
      previousLoads[ref] = store.getValue();
      previousStores[ref] = &op;
      continue;
    }

    // If a function is called, clear known values for loads/stores used by
    // the function or its sub-functions.
    if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
      auto loadSymbols =
          loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
      auto storeSymbols =
          storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];

      // A callee load makes a pending store observable; it may no longer be
      // deleted.
      for (auto sym : loadSymbols) {
        previousStores.erase(sym);
      }

      // A callee store invalidates both the cached value and any pending
      // store.
      for (auto sym : storeSymbols) {
        previousLoads.erase(sym);
        previousStores.erase(sym);
      }
      continue;
    }

    // If the op has sub-regions, recurse inside. We make no guarantees whether
    // the recursion occurs.
    llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
    llvm::DenseSet<SymbolRefAttr> opSymbolStore;
    for (auto &region : op.getRegions()) {
      for (auto &block : region) {
        ProcessBlock(block, opSymbolLoad, opSymbolStore);
      }
    }

    // Update current state from the subblock.
    for (auto change : opSymbolLoad) {
      symbolLoad.insert(change);
      previousStores.erase(change);
    }

    for (auto change : opSymbolStore) {
      symbolStore.insert(change);
      previousLoads.erase(change);
      previousStores.erase(change);
    }
  }

  for (auto op : toDelete) {
    op->erase();
  }
}

void MLProgramPipelineGlobals::runOnOperation() {
auto targetOp = getOperation();
if (failed(buildGlobalMap(targetOp))) {
return;
}

for (auto &funcOp : *targetOp.getBody()) {
for (auto &region : funcOp.getRegions()) {
for (auto &block : region.getBlocks()) {
llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
llvm::DenseSet<SymbolRefAttr> symbolsStored;
ProcessBlock(block, symbolsLoaded, symbolsStored);
}
}
}
}

} // namespace

// Factory referenced by the tablegen'd `constructor` field in Passes.td.
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createMLProgramPipelineGlobalsPass() {
  auto pass = std::make_unique<MLProgramPipelineGlobals>();
  return pass;
}

} // namespace ml_program
} // namespace mlir
Loading

0 comments on commit cbd4750

Please sign in to comment.