From cbd475040f8952cfc55b9e13dd5ce6c4f6434cd3 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Mon, 18 Sep 2023 16:31:38 -0700
Subject: [PATCH] [mlir][mlprogram] Add `mlprogram-pipeline-globals`
 optimization pass

The added pass optimizes MLProgram global operations by reducing them to
the minimal set of load/store operations on global tensors. This avoids
unnecessary global operations throughout a program and potentially
improves operation fusion.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D159228
---
 .../mlir/Dialect/MLProgram/CMakeLists.txt     |   1 +
 .../MLProgram/Transforms/CMakeLists.txt       |   6 +
 .../Dialect/MLProgram/Transforms/Passes.h     |  35 +++
 .../Dialect/MLProgram/Transforms/Passes.td    |  27 ++
 mlir/include/mlir/InitAllPasses.h             |   2 +
 mlir/lib/Dialect/MLProgram/CMakeLists.txt     |   1 +
 .../lib/Dialect/MLProgram/IR/MLProgramOps.cpp |  20 +-
 .../MLProgram/Transforms/CMakeLists.txt       |  14 +
 .../Transforms/PipelineGlobalOps.cpp          | 234 +++++++++++++++++
 .../Dialect/MLProgram/pipeline-globals.mlir   | 246 ++++++++++++++++++
 10 files changed, 582 insertions(+), 4 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
 create mode 100644 mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
 create mode 100644 mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp
 create mode 100644 mlir/test/Dialect/MLProgram/pipeline-globals.mlir

diff --git a/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
index f33061b2d87cff..9f57627c321fb0 100644
--- a/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..c5c11f17a9fa97
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name MLProgram)
+add_public_tablegen_target(MLIRMLProgramPassIncGen)
+add_dependencies(mlir-headers MLIRMLProgramPassIncGen)
+
+add_mlir_doc(Passes MLProgramPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
new file mode 100644
index 00000000000000..894e35e52724e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace ml_program {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+} // namespace ml_program
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
new file mode 100644
index 00000000000000..defe8191cb905d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
@@ -0,0 +1,27 @@
+//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
+#define MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
+  let summary = "Optimize `ml_program` global load and store operations";
+  let description = [{
+    `ml_program`'s load and store operations can be optimized across
+    write-write and write-read pairs: a load can reuse a value that is
+    already known in the IR, and a store that is overwritten before being
+    read can be removed.
+
+    The pass is designed to handle both nested regions and function calls
+    safely.
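+
+    For example (an illustrative sketch; `@foo`, `%0`, and `%new` are
+    placeholder names), a load following a store reuses the stored value,
+    and a store that is overwritten before any intervening read is erased:
+
+    ```mlir
+    ml_program.global_store @foo = %0 : tensor<4xi32>
+    %1 = ml_program.global_load @foo : tensor<4xi32>    // replaced by %0
+    ml_program.global_store @foo = %new : tensor<4xi32> // prior store erased
+    ```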
+ }]; + let constructor = "mlir::ml_program::createMLProgramPipelineGlobalsPass()"; +} + +#endif // MLIR_DIALECT_MLPROGRAM_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 8a45da7d1b982f..5489a13a8040bd 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -26,6 +26,7 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MLProgram/Transforms/Passes.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/NVGPU/Transforms/Passes.h" @@ -72,6 +73,7 @@ inline void registerAllPasses() { LLVM::registerLLVMPasses(); math::registerMathPasses(); memref::registerMemRefPasses(); + ml_program::registerMLProgramPasses(); registerSCFPasses(); registerShapePasses(); spirv::registerSPIRVPasses(); diff --git a/mlir/lib/Dialect/MLProgram/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/CMakeLists.txt index f33061b2d87cff..9f57627c321fb0 100644 --- a/mlir/lib/Dialect/MLProgram/CMakeLists.txt +++ b/mlir/lib/Dialect/MLProgram/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp index f8f75495660395..5352b7b0454fd1 100644 --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -178,8 +178,14 @@ LogicalResult GlobalOp::verify() { //===----------------------------------------------------------------------===// GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) { - return symbolTable.lookupNearestSymbolFrom( - getOperation()->getParentOp(), getGlobalAttr()); + for (auto parent = getOperation()->getParentOp(); parent; + parent = parent->getParentOp()) { + if (auto nearest = symbolTable.lookupNearestSymbolFrom( + parent, getGlobalAttr())) { + return nearest; + } + } + return {}; } LogicalResult @@ -253,8 +259,14 @@ GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { //===----------------------------------------------------------------------===// GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) { - return symbolTable.lookupNearestSymbolFrom( - getOperation()->getParentOp(), getGlobalAttr()); + for (auto parent = getOperation()->getParentOp(); parent;) { + if (auto nearest = symbolTable.lookupNearestSymbolFrom( + parent, getGlobalAttr())) { + return nearest; + } + parent = parent->getParentOp(); + } + return {}; } LogicalResult diff --git a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt new file mode 100644 index 00000000000000..db567b62e0e747 --- /dev/null +++ b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRMLProgramTransforms + PipelineGlobalOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram/Transforms + + DEPENDS + MLIRMLProgramPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRMLProgramDialect + MLIRPass +) diff --git a/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp b/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp new file mode 100644 index 00000000000000..7e00a731f6d731 --- /dev/null +++ b/mlir/lib/Dialect/MLProgram/Transforms/PipelineGlobalOps.cpp @@ -0,0 +1,234 @@ +//===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass 
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace ml_program {
+#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+namespace {
+
+class MLProgramPipelineGlobals
+    : public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
+public:
+  void runOnOperation() override;
+
+private:
+  LogicalResult buildGlobalMap(ModuleOp op);
+
+  void processBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
+                    llvm::DenseSet<SymbolRefAttr> &symbolStore);
+
+  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
+  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
+};
+
+// Traverses upwards, searching for the operation that the symbol maps to.
+static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
+  for (auto op = baseOp; op; op = op->getParentOp()) {
+    auto lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
+    if (lookup)
+      return lookup;
+  }
+  return nullptr;
+}
+
+// Builds a map from each callable symbol to the `ml_program` global symbols
+// it loads or stores, including globals touched by its transitive callees.
+LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
+  llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
+  auto res = module->walk([&](Operation *op) {
+    if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
+      auto callable = caller.getCallableForCallee();
+      // For now we do not know how to handle Value-based tracing, so fail.
+      if (mlir::isa<Value>(callable)) {
+        return WalkResult::interrupt();
+      }
+
+      auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
+      auto func = getFromSymbol(op, symbol);
+      callableMap[symbol] = func;
+    }
+    return WalkResult::advance();
+  });
+
+  if (res.wasInterrupted()) {
+    return failure();
+  }
+
+  // First grab all symbols loaded or stored directly by each function.
+  // Calls are not yet accounted for at this point.
+  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
+  llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
+  for (auto callable : callableMap) {
+    llvm::DenseSet<SymbolRefAttr> loadSymbols;
+    llvm::DenseSet<SymbolRefAttr> storeSymbols;
+
+    callable.getSecond()->walk(
+        [&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
+
+    callable.getSecond()->walk(
+        [&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
+
+    opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
+    opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
+  }
+
+  // For each callable function, collect every global loaded or stored within
+  // the function or any function it transitively calls. Visited callees are
+  // tracked so that recursive call cycles do not loop forever.
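+  // The loop below is a simple worklist traversal: `work` grows as new
+  // callees are discovered while iterating by index, and `visited` ensures
+  // each callee is enqueued at most once, so call cycles (including direct
+  // recursion) terminate after each callee is visited exactly once.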
+  for (auto callable : callableMap) {
+    SymbolRefAttr thisSymbol = callable.first;
+    llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
+    llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
+    llvm::DenseSet<SymbolRefAttr> loadSymbols;
+    llvm::DenseSet<SymbolRefAttr> storeSymbols;
+
+    for (size_t i = 0; i < work.size(); ++i) {
+      callableMap[work[i]]->walk([&](CallOpInterface call) {
+        auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
+        if (!visited.contains(symbol)) {
+          visited.insert(symbol);
+          work.push_back(symbol);
+        }
+      });
+
+      for (auto load : opLoadSymbols[work[i]])
+        loadSymbols.insert(load);
+
+      for (auto store : opStoreSymbols[work[i]])
+        storeSymbols.insert(store);
+    }
+
+    loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
+    storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
+  }
+
+  return success();
+}
+
+// Processes each operation in the block, deleting unneeded loads and stores,
+// recursing into sub-blocks and accounting for function calls.
+void MLProgramPipelineGlobals::processBlock(
+    Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
+    llvm::DenseSet<SymbolRefAttr> &symbolStore) {
+
+  llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
+  llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
+  llvm::SmallVector<Operation *> toDelete;
+  for (auto &op : block) {
+    // If this is a global load, reuse the previously known value if one
+    // exists and mark this load for deletion. Otherwise record this result
+    // as the currently known value of the global.
+    if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
+      auto ref = load.getGlobal();
+      symbolLoad.insert(ref);
+      if (previousLoads.contains(ref)) {
+        toDelete.push_back(&op);
+        load.getResult().replaceAllUsesWith(previousLoads[ref]);
+      } else {
+        previousLoads[ref] = load.getResult();
+      }
+      continue;
+    }
+
+    // Delete the previous store to this global if one exists (this store
+    // makes it redundant), then update the most recent known value for
+    // this global ref.
+    if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
+      auto ref = store.getGlobal();
+      symbolStore.insert(ref);
+      if (previousStores.contains(ref)) {
+        toDelete.push_back(previousStores.find(ref)->getSecond());
+      }
+
+      previousLoads[ref] = store.getValue();
+      previousStores[ref] = &op;
+      continue;
+    }
+
+    // If a function is called, invalidate the known values for any globals
+    // loaded or stored by the callee or its transitive callees.
+    if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
+      auto callee = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
+      auto loadSymbols = loadSymbolsMap[callee];
+      auto storeSymbols = storeSymbolsMap[callee];
+
+      for (auto sym : loadSymbols) {
+        previousStores.erase(sym);
+      }
+
+      for (auto sym : storeSymbols) {
+        previousLoads.erase(sym);
+        previousStores.erase(sym);
+      }
+      continue;
+    }
+
+    // If the op has sub-regions, recurse inside. As no guarantees can be
+    // made about whether or how often a region executes, its effects are
+    // treated conservatively.
+    llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
+    llvm::DenseSet<SymbolRefAttr> opSymbolStore;
+    for (auto &region : op.getRegions()) {
+      for (auto &block : region) {
+        processBlock(block, opSymbolLoad, opSymbolStore);
+      }
+    }
+
+    // Update the current state from the sub-blocks.
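+    // The merge is conservative: a load in the region pins any pending
+    // store (it must remain visible to the region), while a store in the
+    // region invalidates both the cached load value and the pending store.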
+    for (auto change : opSymbolLoad) {
+      symbolLoad.insert(change);
+      previousStores.erase(change);
+    }
+
+    for (auto change : opSymbolStore) {
+      symbolStore.insert(change);
+      previousLoads.erase(change);
+      previousStores.erase(change);
+    }
+  }
+
+  for (auto op : toDelete) {
+    op->erase();
+  }
+}
+
+void MLProgramPipelineGlobals::runOnOperation() {
+  auto targetOp = getOperation();
+  if (failed(buildGlobalMap(targetOp))) {
+    return;
+  }
+
+  for (auto &funcOp : *targetOp.getBody()) {
+    for (auto &region : funcOp.getRegions()) {
+      for (auto &block : region.getBlocks()) {
+        llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
+        llvm::DenseSet<SymbolRefAttr> symbolsStored;
+        processBlock(block, symbolsLoaded, symbolsStored);
+      }
+    }
+  }
+}
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createMLProgramPipelineGlobalsPass() {
+  return std::make_unique<MLProgramPipelineGlobals>();
+}
+
+} // namespace ml_program
+} // namespace mlir
diff --git a/mlir/test/Dialect/MLProgram/pipeline-globals.mlir b/mlir/test/Dialect/MLProgram/pipeline-globals.mlir
new file mode 100644
index 00000000000000..a5c9b3e890558e
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/pipeline-globals.mlir
@@ -0,0 +1,246 @@
+// RUN: mlir-opt -split-input-file -pass-pipeline="builtin.module(mlprogram-pipeline-globals)" --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_double_load
+func.func @global_double_load() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  // CHECK-NOT: ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+  %1 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]], %[[LOAD]])
+  %2 = "unregistered.dummy"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %2 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_double_store
+func.func @global_double_store() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+
+  // CHECK-NOT: ml_program.global_store
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_store_load
+func.func @global_store_load() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  // CHECK: %[[DUMMY2:.+]] = "unregistered.dummy"(%[[DUMMY]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY2]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @global_store_load_region
+func.func @global_store_load_region() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+
+  // CHECK: "unregistered.dummy2"
+  "unregistered.dummy2"() ({
+  ^bb():
+    %cst = arith.constant dense<0> : tensor<4xi32>
+    // CHECK: ml_program.global_store @global_variable
+    ml_program.global_store @global_variable = %cst : tensor<4xi32>
+    "unregistered.terminator"() : () -> ()
+  }) : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY2:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY2]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+  %cst = arith.constant dense<0> : tensor<4xi32>
+  // CHECK: ml_program.global_store
+  ml_program.global_store @global_variable = %cst : tensor<4xi32>
+  func.return
+}
+
+// CHECK-LABEL: @call_global_store
+func.func @call_global_store() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  call @interrupt() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt_indirect
+func.func @interrupt_indirect() {
+  %cst = arith.constant dense<0> : tensor<4xi32>
+  // CHECK: ml_program.global_store
+  ml_program.global_store @global_variable = %cst : tensor<4xi32>
+  func.return
+}
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+  call @interrupt_indirect() : () -> ()
+  func.return
+}
+
+// CHECK-LABEL: @call_indirect_store
+func.func @call_indirect_store() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  call @interrupt() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @interrupt_indirect
+func.func @interrupt_indirect() -> tensor<4xi32> {
+  // CHECK: ml_program.global_load
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+  func.return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @interrupt
+func.func @interrupt() {
+  %0 = call @interrupt_indirect() : () -> (tensor<4xi32>)
+  "unregistered.dummy"(%0) : (tensor<4xi32>) -> ()
+  func.return
+}
+
+// CHECK-LABEL: @call_indirect_load
+func.func @call_indirect_load() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  call @interrupt() : () -> ()
+
+  // The callee only reads the global, so the second load is forwarded
+  // from the stored value.
+  // CHECK: %[[DUMMY2:.+]] = "unregistered.dummy"(%[[DUMMY]])
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY2]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: @global_variable
+ml_program.global private mutable @global_variable(dense<4> : tensor<4xi32>) : tensor<4xi32>
+
+// CHECK-LABEL: @call_recursive
+func.func @call_recursive() {
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %0 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %1 = "unregistered.dummy"(%0) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %1 : tensor<4xi32>
+  call @call_recursive() : () -> ()
+
+  // CHECK: %[[LOAD:.+]] = ml_program.global_load @global_variable
+  %2 = ml_program.global_load @global_variable : tensor<4xi32>
+
+  // CHECK: %[[DUMMY:.+]] = "unregistered.dummy"(%[[LOAD]])
+  %3 = "unregistered.dummy"(%2) : (tensor<4xi32>) -> (tensor<4xi32>)
+
+  // CHECK: ml_program.global_store @global_variable = %[[DUMMY]]
+  ml_program.global_store @global_variable = %3 : tensor<4xi32>
+  func.return
+}