Skip to content

Commit

Permalink
[IR] add the erase api for region&block. (#54844)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Jun 26, 2023
1 parent ef445ec commit e50266f
Show file tree
Hide file tree
Showing 18 changed files with 248 additions and 194 deletions.
7 changes: 7 additions & 0 deletions paddle/ir/core/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/region.h"

Expand All @@ -32,6 +33,12 @@ Block::iterator Block::insert(const_iterator iterator, Operation *op) {
return iter;
}

Block::iterator Block::erase(const_iterator position) {
IR_ENFORCE((*position)->GetParent() == this, "iterator not own this block.");
(*position)->Destroy();
return ops_.erase(position);
}

void Block::clear() {
while (!empty()) {
ops_.back()->Destroy();
Expand Down
1 change: 1 addition & 0 deletions paddle/ir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class IR_API Block {
void push_back(Operation *op);
void push_front(Operation *op);
iterator insert(const_iterator iterator, Operation *op);
iterator erase(const_iterator position);
void clear();
operator Region::iterator() { return position_; }

Expand Down
33 changes: 31 additions & 2 deletions paddle/ir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ Program *ModuleOp::program() {
Block *ModuleOp::block() {
assert(operation() != nullptr);
assert(operation()->num_regions() == 1);
assert(operation()->GetRegion(0).size() == 1);
return operation()->GetRegion(0).front();
assert(operation()->region(0).size() == 1);
return operation()->region(0).front();
}

ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
Expand Down Expand Up @@ -71,6 +71,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};

void GetParameterOp::Build(Builder &builder,
OperationArgument &argument,
const std::string &name,
Type type) {
argument.attributes[attributes_name[0]] =
ir::StrAttribute::get(builder.ir_context(), name);
argument.output_types.emplace_back(type);
}

void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
Expand All @@ -90,6 +99,14 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const char *SetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};

void SetParameterOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name) {
argument.AddOperand(parameter);
argument.AddAttribute(attributes_name[0],
ir::StrAttribute::get(builder.ir_context(), name));
}
void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
Expand All @@ -106,6 +123,18 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}

void CombineOp::Build(Builder &builder,
OperationArgument &argument,
const std::vector<ir::OpResult> &inputs) {
argument.inputs = inputs;
std::vector<ir::Type> inputs_type(inputs.size());
for (size_t idx = 0; idx < inputs.size(); ++idx) {
inputs_type[idx] = inputs[idx].type();
}
argument.output_types.emplace_back(
ir::VectorType::get(builder.ir_context(), inputs_type));
}

void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
Expand Down
16 changes: 14 additions & 2 deletions paddle/ir/core/builtin_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
static const char *name() { return "builtin.get_parameter"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::string &name,
Type type);
static void Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const ir::AttributeMap &attributes);
};

Expand All @@ -69,6 +73,10 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
static const char *name() { return "builtin.set_parameter"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name);
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
Expand All @@ -87,6 +95,10 @@ class IR_API CombineOp : public ir::Op<CombineOp> {

static constexpr const char **attributes_name = nullptr;

static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::vector<ir::OpResult> &inputs);

static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
void IrPrinter::PrintProgram(Program* program) {
auto top_level_op = program->module_op();
for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
auto& region = top_level_op->GetRegion(i);
auto& region = top_level_op->region(i);
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
os << "{\n";
Expand Down Expand Up @@ -153,7 +153,7 @@ void IrPrinter::PrintFullOperation(Operation* op) {
os << newline;
}
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
auto& region = op->region(i);
PrintRegion(region);
}
}
Expand Down
73 changes: 28 additions & 45 deletions paddle/ir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Operation *Operation::Create(OperationArgument &&argument) {
argument.regions.size());

for (size_t index = 0; index < argument.regions.size(); ++index) {
op->GetRegion(index).TakeBody(std::move(*argument.regions[index]));
op->region(index).TakeBody(std::move(*argument.regions[index]));
}
return op;
}
Expand Down Expand Up @@ -103,17 +103,35 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
return op;
}

// Call destructors for OpResults, Operation, and OpOperands in sequence, and
// finally free memory.
// Call destructors for Region , OpResults, Operation, and OpOperands in
// sequence, and finally free memory.
void Operation::Destroy() {
// Deconstruct Regions.
// 1. Deconstruct Regions.
if (num_regions_ > 0) {
for (size_t idx = 0; idx < num_regions_; idx++) {
regions_[idx].~Region();
}
}

// 1. Get aligned_ptr by result_num.
// 2. Deconstruct Result.
for (size_t idx = 0; idx < num_results_; ++idx) {
detail::OpResultImpl *impl = result(idx).impl();
IR_ENFORCE(impl->use_empty(), "operation destroyed but still has uses.");
if (detail::OpOutlineResultImpl::classof(*impl)) {
static_cast<detail::OpOutlineResultImpl *>(impl)->~OpOutlineResultImpl();
} else {
static_cast<detail::OpInlineResultImpl *>(impl)->~OpInlineResultImpl();
}
}

// 3. Deconstruct Operation.
this->~Operation();

// 4. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
operand(idx).impl()->~OpOperandImpl();
}
// 5. Free memory.
uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
size_t result_mem_size =
Expand All @@ -122,46 +140,11 @@ void Operation::Destroy() {
(num_results_ - max_inline_result_num) +
sizeof(detail::OpInlineResultImpl) * max_inline_result_num
: sizeof(detail::OpInlineResultImpl) * num_results_;
char *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;
// 2.1. Deconstruct OpResult.
char *base_ptr = aligned_ptr;
for (size_t idx = num_results_; idx > 0; idx--) {
// release the uses of this result
detail::OpOperandImpl *first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
while (first_use != nullptr) {
first_use->RemoveFromUdChain();
first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
}
// destory the result
if (idx > max_inline_result_num) {
reinterpret_cast<detail::OpOutlineResultImpl *>(base_ptr)
->~OpOutlineResultImpl();
base_ptr += sizeof(detail::OpOutlineResultImpl);
} else {
reinterpret_cast<detail::OpInlineResultImpl *>(base_ptr)
->~OpInlineResultImpl();
base_ptr += sizeof(detail::OpInlineResultImpl);
}
}
// 2.2. Deconstruct Operation.
if (reinterpret_cast<uintptr_t>(base_ptr) !=
reinterpret_cast<uintptr_t>(this)) {
IR_THROW("Operation address error");
}
reinterpret_cast<Operation *>(base_ptr)->~Operation();
base_ptr += sizeof(Operation);
// 2.3. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
reinterpret_cast<detail::OpOperandImpl *>(base_ptr)->~OpOperandImpl();
base_ptr += sizeof(detail::OpOperandImpl);
}
// 3. Free memory.
VLOG(4) << "Destroy an Operation: {ptr = "
<< reinterpret_cast<void *>(aligned_ptr)
void *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;

VLOG(4) << "Destroy an Operation: {ptr = " << aligned_ptr
<< ", size = " << result_mem_size << "}";
aligned_free(reinterpret_cast<void *>(aligned_ptr));
aligned_free(aligned_ptr);
}

IrContext *Operation::ir_context() const { return info_.ir_context(); }
Expand Down Expand Up @@ -231,7 +214,7 @@ Program *Operation::GetParentProgram() {
return module_op ? module_op.program() : nullptr;
}

Region &Operation::GetRegion(unsigned index) {
Region &Operation::region(unsigned index) {
assert(index < num_regions_ && "invalid region index");
return regions_[index];
}
Expand Down
8 changes: 5 additions & 3 deletions paddle/ir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class IR_API alignas(8) Operation final {

OpOperand operand(uint32_t index) const;

/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);

void Print(std::ostream &os);

const AttributeMap &attributes() const { return attributes_; }
Expand Down Expand Up @@ -95,11 +98,10 @@ class IR_API alignas(8) Operation final {

Program *GetParentProgram();

/// Returns the region held by this operation at position 'index'.
Region &GetRegion(unsigned index);

operator Block::iterator() { return position_; }

operator Block::const_iterator() const { return position_; }

private:
Operation(const AttributeMap &attribute,
ir::OpInfo op_info,
Expand Down
6 changes: 6 additions & 0 deletions paddle/ir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ struct OperationArgument {
info(info),
regions(std::move(regions)) {}

/// Add Operand.
void AddOperand(OpResult operand) { inputs.emplace_back(operand); }

template <class InputIt>
void AddOperands(InputIt first, InputIt last);

/// Add Output.
void AddOutput(Type type) { output_types.emplace_back(type); }

template <class InputIt>
void AddTypes(InputIt first, InputIt last);

Expand Down
7 changes: 7 additions & 0 deletions paddle/ir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/ir/core/region.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"

namespace ir {
Region::~Region() { clear(); }
Expand All @@ -29,6 +30,12 @@ Region::iterator Region::insert(const_iterator position, Block *block) {
block->SetParent(this, iter);
return iter;
}

Region::iterator Region::erase(const_iterator position) {
IR_ENFORCE((*position)->GetParent() == this, "iterator not own this region.");
delete *position;
return blocks_.erase(position);
}
void Region::TakeBody(Region &&other) {
clear();
blocks_.swap(other.blocks_);
Expand Down
1 change: 1 addition & 0 deletions paddle/ir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class IR_API Region {
void emplace_back();
void push_front(Block *block);
iterator insert(const_iterator position, Block *block);
iterator erase(const_iterator position);
void clear();

void TakeBody(Region &&other);
Expand Down
31 changes: 21 additions & 10 deletions paddle/ir/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,22 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
}
OpOperand::operator bool() const { return impl_ && impl_->source(); }

OpOperand OpOperand::next_use() const { return impl_->next_use(); }
OpOperand OpOperand::next_use() const { return impl()->next_use(); }

Value OpOperand::source() const { return impl_->source(); }
Value OpOperand::source() const { return impl()->source(); }

void OpOperand::set_source(Value value) {
IR_ENFORCE(impl_, "Can't set source for a null value.");
impl_->set_source(value);
}
Type OpOperand::type() const { return source().type(); }

void OpOperand::set_source(Value value) { impl()->set_source(value); }

Operation *OpOperand::owner() const { return impl()->owner(); }

Operation *OpOperand::owner() const { return impl_->owner(); }
void OpOperand::RemoveFromUdChain() { return impl()->RemoveFromUdChain(); }

detail::OpOperandImpl *OpOperand::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while operand is null.");
return impl_;
}
// Value
Value::Value(const detail::ValueImpl *impl)
: impl_(const_cast<detail::ValueImpl *>(impl)) {}
Expand Down Expand Up @@ -84,13 +89,18 @@ void Value::ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const {
for (auto it = begin(); it != end();) {
auto cur = it++;
if (should_replace(*cur)) {
cur->set_source(new_value);
if (should_replace(*it)) {
(it++)->set_source(new_value);
}
}
}

void Value::ReplaceAllUsesWith(Value new_value) const {
for (auto it = begin(); it != end();) {
(it++)->set_source(new_value);
}
}

detail::ValueImpl *Value::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return impl_;
Expand All @@ -106,6 +116,7 @@ Operation *OpResult::owner() const { return impl()->owner(); }
uint32_t OpResult::GetResultIndex() const { return impl()->GetResultIndex(); }

detail::OpResultImpl *OpResult::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return reinterpret_cast<detail::OpResultImpl *>(impl_);
}

Expand Down
Loading

0 comments on commit e50266f

Please sign in to comment.