Skip to content

Commit

Permalink
[PIR] polish the pir interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Nov 27, 2023
1 parent 3af9eb7 commit ef583ba
Show file tree
Hide file tree
Showing 17 changed files with 52 additions and 51 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void GroupOp::Build(pir::Builder &builder, // NOLINT
argument.AddOutput(op.operand(i).type());
}
}
argument.AddRegion()->push_back(block.release());
argument.AddRegion().push_back(block.release());
}

pir::Block *GroupOp::block() {
Expand Down
14 changes: 7 additions & 7 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
"equal. but they are %u and 0, respectively",
argument.output_types.size()));
}
argument.AddRegion()->push_back(true_block.release());
argument.AddRegion()->push_back(false_block.release());
argument.AddRegion().push_back(true_block.release());
argument.AddRegion().push_back(false_block.release());
argument.AddInput(cond);
}

Expand Down Expand Up @@ -237,10 +237,10 @@ void WhileOp::Build(pir::Builder &builder, // NOLINT
}
argument.AddRegion(nullptr);
}
pir::Block *WhileOp::body_block() {
pir::Block &WhileOp::body_block() {
pir::Region &body_region = (*this)->region(0);
if (body_region.empty()) body_region.emplace_back();
return &body_region.front();
return body_region.front();
}
pir::Value WhileOp::cond() { return (*this)->operand_source(0); }

Expand All @@ -259,11 +259,11 @@ void WhileOp::Print(pir::IrPrinter &printer) {
[&]() { os << ", "; });
os << "] { \n ^";
pir::PrintInterleave(
body_block()->args_begin(),
body_block()->args_end(),
body_block().args_begin(),
body_block().args_end(),
[&](pir::Value v) { printer.PrintValue(v); },
[&]() { os << ", "; });
for (auto &item : *body_block()) {
for (auto &item : body_block()) {
os << "\n ";
printer.PrintOperation(&item);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class WhileOp : public pir::Op<WhileOp> {
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
const std::vector<pir::Value> &inputs);
pir::Block *body_block();
pir::Block &body_block();
pir::Value cond();
void Print(pir::IrPrinter &printer); // NOLINT
void VerifySig() {}
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,9 @@ class InplacePass : public pir::Pass {
void Run(pir::Operation* op) override {
auto module_op = op->dyn_cast<pir::ModuleOp>();
IR_ENFORCE(module_op, "inplace_pass should run on module op.");
auto* block = module_op.block();
auto& block = module_op.block();

auto inplace_ops = details::GetInplaceOps(block);
auto inplace_ops = details::GetInplaceOps(&block);

for (auto kv : inplace_ops) {
VLOG(6) << "Do inplace for: "
Expand All @@ -458,8 +458,8 @@ class InplacePass : public pir::Pass {
.dyn_cast<pir::StrAttribute>()
.AsString();
pir::Block::Iterator insert_pos =
std::find(block->begin(), block->end(), *kv.first);
IR_ENFORCE(insert_pos != block->end(),
std::find(block.begin(), block.end(), *kv.first);
IR_ENFORCE(insert_pos != block.end(),
"Operator %s not found in block.",
kv.first->name());

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class ParamsSyncAmongDevicesPass : public pir::Pass {
module_op,
phi::errors::PreconditionNotMet(
"params_sync_among_devices_pass should run on module op."));
auto* block = module_op.block();
for (auto& inner_op : *block) {
auto& block = module_op.block();
for (auto& inner_op : block) {
if (inner_op.isa<pir::ParameterOp>()) {
std::string param_name = inner_op.attributes()
.at("parameter_name")
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -925,15 +925,15 @@ void HandleForWhileOp(
pir::Builder builder(ctx, block);
auto base_while_op = op_item->dyn_cast<WhileOp>();
auto new_while_op = builder.Build<WhileOp>(cond_val, vec_in);
pir::Block* body_block = new_while_op.body_block();
pir::Block& body_block = new_while_op.body_block();
for (size_t i = 0; i < vec_in.size(); ++i) {
auto block_arg = body_block->AddArgument(vec_in[i].type());
(*map_value_pair)[base_while_op.body_block()->argument(i)] = block_arg;
auto block_arg = body_block.AddArgument(vec_in[i].type());
(*map_value_pair)[base_while_op.body_block().argument(i)] = block_arg;
}

// process body block
ProcessBlock(place,
base_while_op.body_block(),
&base_while_op.body_block(),
body_block,
ctx,
map_op_pair,
Expand Down
12 changes: 7 additions & 5 deletions paddle/pir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,13 @@ Program *ModuleOp::program() {
iter->second.dyn_cast<PointerAttribute>().data());
}

Block *ModuleOp::block() {
assert(operation() != nullptr);
assert(operation()->num_regions() == 1);
assert(operation()->region(0).size() == 1);
return &operation()->region(0).front();
Block &ModuleOp::block() {
IR_ENFORCE(operation()->num_regions(),
"The region size of ModuleOp must be equal to 1.");
auto &region = (*this)->region(0);
IR_ENFORCE(region.size() == 1,
"The region size of ModuleOp must be equal to 1.");
return region.front();
}

ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/builtin_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class IR_API ModuleOp : public pir::Op<ModuleOp> {
static const char *attributes_name[attributes_num];
void VerifySig() const;
Program *program();
Block *block();
Block &block();

//
// As the top operation, ModuleOp only support create&destroye through
Expand Down
5 changes: 2 additions & 3 deletions paddle/pir/core/interface_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ConstructInterfacesOrTraits {
/// Construct method for interfaces.
static void interface(InterfaceSet &interface_set) { // NOLINT
(void)std::initializer_list<int>{
0, (PlacementConstrctInterface<Args>(interface_set), 0)...};
0, (ConstrctInterface<Args>(interface_set), 0)...};
}

/// Construct method for traits.
Expand All @@ -38,8 +38,7 @@ class ConstructInterfacesOrTraits {
private:
/// Placement new interface.
template <typename T>
static void PlacementConstrctInterface(
InterfaceSet &interface_set) { // NOLINT
static void ConstrctInterface(InterfaceSet &interface_set) { // NOLINT
InterfaceValue val = InterfaceValue::
Get<ConcreteT, T, typename T::template Model<ConcreteT>>();
auto suceess = interface_set.insert(std::move(val)).second;
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/operation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ OperationArgument::OperationArgument(IrContext* ir_context,
info = ir_context->GetRegisteredOpInfo(name);
}

Region* OperationArgument::AddRegion() {
Region& OperationArgument::AddRegion() {
regions.emplace_back(new Region);
return regions.back().get();
return *regions.back();
}

/// Take a region that should be attached to the Operation.
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct OperationArgument {
/// Create a region that should be attached to the operation. These regions
/// can be filled in immediately without waiting for Operation to be
/// created. When it is, the region bodies will be transferred.
Region* AddRegion();
Region& AddRegion();

/// Take a region that should be attached to the Operation. The body of the
/// region will be transferred when the Operation is created. If the
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class IR_API Program {

static std::unique_ptr<Program> Parse(std::istream& is, IrContext* ctx);

Block* block() { return module_.block(); }
const Block* block() const { return module_op().block(); }
Block* block() { return &module_.block(); }
const Block* block() const { return &module_op().block(); }

Parameter* GetParameter(const std::string& name) const;
void SetParameter(const std::string& name,
Expand Down
10 changes: 5 additions & 5 deletions paddle/pir/dialect/shape/utils/shape_optimization_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ bool CompareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT
}

SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) {
for (auto& op : *(m.block())) {
for (auto& op : m.block()) {
if (op.isa<shape::FuncOp>()) {
symbol_table_ = SymbolTable(&op);
return;
}
}
Builder builder = Builder(m_.ir_context(), m_.block(), m_.block()->begin());
Builder builder = Builder(m_.ir_context(), &m_.block(), m_.block().begin());
shape::FuncOp func = builder.Build<shape::FuncOp>();
symbol_table_ = SymbolTable(func);
}
Expand Down Expand Up @@ -473,7 +473,7 @@ bool SymbolicDimMgr::Save() {
};

// TODO(zhangbopd): update attributes attached in DenseTensorType
for (auto& op : *(m_.block())) {
for (auto& op : m_.block()) {
if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue;
auto attrs =
op.attribute<ArrayAttribute>(SymbolicDimOp::GetSymbolicDimAttrName());
Expand All @@ -499,7 +499,7 @@ bool SymbolicDimMgr::Save() {
used_symbol_names.push_back(sym.GetSymName());
}
};
for (auto& op : *(m_.block())) {
for (auto& op : m_.block()) {
if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue;
auto attrs =
op.attribute<ArrayAttribute>(SymbolicDimOp::GetSymbolicDimAttrName());
Expand Down Expand Up @@ -559,7 +559,7 @@ bool SymbolicDimMgr::Save() {
name_to_symbol[name] = op;
}

for (auto& op : *(m_.block())) {
for (auto& op : m_.block()) {
if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue;
auto attrs =
op.attribute<ArrayAttribute>(SymbolicDimOp::GetSymbolicDimAttrName());
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/dialect/shape/utils/shape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ bool ShapeAnalysis::IsProductEqual(
ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m)
: m_(m), mgr_(m) {
mgr_.Load();
for (auto& op : *(m_.block())) {
for (auto& op : m.block()) {
auto tie_shape_op = op.dyn_cast<shape::TieShapeOp>();
if (!tie_shape_op) continue;
Value result = tie_shape_op.input();
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/new_executor/standalone_executor_pir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,10 @@ TEST(StandaloneExecutor, while_op) {
builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, ten});

// { i = i + 1}
pir::Block* body_block = while_op.body_block();
auto body_i_argument = body_block->AddArgument(i.type());
auto body_ten_argument = body_block->AddArgument(ten.type());
builder.SetInsertionPointToStart(body_block);
pir::Block& body_block = while_op.body_block();
auto body_i_argument = body_block.AddArgument(i.type());
auto body_ten_argument = body_block.AddArgument(ten.type());
builder.SetInsertionPointToStart(&body_block);
auto one =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
.out();
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/pir/control_flow_dialect/while_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ TEST(while_op_test, base) {
builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, ten});

// { i = i + 1}
pir::Block* body_block = while_op.body_block();
auto body_i_argument = body_block->AddArgument(i.type());
auto body_ten_argument = body_block->AddArgument(ten.type());
builder.SetInsertionPointToStart(body_block);
pir::Block& body_block = while_op.body_block();
auto body_i_argument = body_block.AddArgument(i.type());
auto body_ten_argument = body_block.AddArgument(ten.type());
builder.SetInsertionPointToStart(&body_block);
auto one =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
.out();
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/pir/core/program_translator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ TEST(OperatorDialectTest, WhileOpProgram) {
EXPECT_TRUE(op.isa<paddle::dialect::WhileOp>());
EXPECT_EQ(op.num_regions(), 1u);
// body block
pir::Block *body_block =
pir::Block &body_block =
op.dyn_cast<paddle::dialect::WhileOp>().body_block();
size_t body_id = 0;
for (auto &op1 : *body_block) {
for (auto &op1 : body_block) {
if (body_id == 0) {
EXPECT_TRUE(op1.isa<paddle::dialect::FullOp>());
}
Expand All @@ -307,10 +307,10 @@ TEST(OperatorDialectTest, WhileOpProgram) {
EXPECT_TRUE(op1.isa<paddle::dialect::LessThanOp>());
}
if (body_id == 3) {
pir::Block *body_body_block =
pir::Block &body_body_block =
op1.dyn_cast<paddle::dialect::WhileOp>().body_block();
size_t body_body_id = 0;
for (auto &op2 : *body_body_block) {
for (auto &op2 : body_body_block) {
if (body_body_id == 0) {
EXPECT_TRUE(op2.isa<paddle::dialect::FullOp>());
}
Expand Down

0 comments on commit ef583ba

Please sign in to comment.