Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] add the erase api for region&block. #54844

Merged
merged 1 commit into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions paddle/ir/core/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/region.h"

Expand All @@ -32,6 +33,12 @@ Block::iterator Block::insert(const_iterator iterator, Operation *op) {
return iter;
}

Block::iterator Block::erase(const_iterator position) {
IR_ENFORCE((*position)->GetParent() == this, "iterator not own this block.");
(*position)->Destroy();
return ops_.erase(position);
}

void Block::clear() {
while (!empty()) {
ops_.back()->Destroy();
Expand Down
1 change: 1 addition & 0 deletions paddle/ir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class IR_API Block {
void push_back(Operation *op);
void push_front(Operation *op);
iterator insert(const_iterator iterator, Operation *op);
iterator erase(const_iterator position);
void clear();
operator Region::iterator() { return position_; }

Expand Down
33 changes: 31 additions & 2 deletions paddle/ir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ Program *ModuleOp::program() {
Block *ModuleOp::block() {
assert(operation() != nullptr);
assert(operation()->num_regions() == 1);
assert(operation()->GetRegion(0).size() == 1);
return operation()->GetRegion(0).front();
assert(operation()->region(0).size() == 1);
return operation()->region(0).front();
}

ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
Expand Down Expand Up @@ -71,6 +71,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};

void GetParameterOp::Build(Builder &builder,
OperationArgument &argument,
const std::string &name,
Type type) {
argument.attributes[attributes_name[0]] =
ir::StrAttribute::get(builder.ir_context(), name);
argument.output_types.emplace_back(type);
}

void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
Expand All @@ -90,6 +99,14 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const char *SetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};

void SetParameterOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name) {
argument.AddOperand(parameter);
argument.AddAttribute(attributes_name[0],
ir::StrAttribute::get(builder.ir_context(), name));
}
void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
Expand All @@ -106,6 +123,18 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}

void CombineOp::Build(Builder &builder,
OperationArgument &argument,
const std::vector<ir::OpResult> &inputs) {
argument.inputs = inputs;
std::vector<ir::Type> inputs_type(inputs.size());
for (size_t idx = 0; idx < inputs.size(); ++idx) {
inputs_type[idx] = inputs[idx].type();
}
argument.output_types.emplace_back(
ir::VectorType::get(builder.ir_context(), inputs_type));
}

void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
Expand Down
16 changes: 14 additions & 2 deletions paddle/ir/core/builtin_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
static const char *name() { return "builtin.get_parameter"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::string &name,
Type type);
static void Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const ir::AttributeMap &attributes);
};

Expand All @@ -69,6 +73,10 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
static const char *name() { return "builtin.set_parameter"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name);
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
Expand All @@ -87,6 +95,10 @@ class IR_API CombineOp : public ir::Op<CombineOp> {

static constexpr const char **attributes_name = nullptr;

static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::vector<ir::OpResult> &inputs);

static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
Expand Down
4 changes: 2 additions & 2 deletions paddle/ir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
void IrPrinter::PrintProgram(Program* program) {
auto top_level_op = program->module_op();
for (size_t i = 0; i < top_level_op->num_regions(); ++i) {
auto& region = top_level_op->GetRegion(i);
auto& region = top_level_op->region(i);
for (auto it = region.begin(); it != region.end(); ++it) {
auto* block = *it;
os << "{\n";
Expand Down Expand Up @@ -153,7 +153,7 @@ void IrPrinter::PrintFullOperation(Operation* op) {
os << newline;
}
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
auto& region = op->region(i);
PrintRegion(region);
}
}
Expand Down
73 changes: 28 additions & 45 deletions paddle/ir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Operation *Operation::Create(OperationArgument &&argument) {
argument.regions.size());

for (size_t index = 0; index < argument.regions.size(); ++index) {
op->GetRegion(index).TakeBody(std::move(*argument.regions[index]));
op->region(index).TakeBody(std::move(*argument.regions[index]));
}
return op;
}
Expand Down Expand Up @@ -103,17 +103,35 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
return op;
}

// Call destructors for OpResults, Operation, and OpOperands in sequence, and
// finally free memory.
// Call destructors for Region , OpResults, Operation, and OpOperands in
// sequence, and finally free memory.
void Operation::Destroy() {
// Deconstruct Regions.
// 1. Deconstruct Regions.
if (num_regions_ > 0) {
for (size_t idx = 0; idx < num_regions_; idx++) {
regions_[idx].~Region();
}
}

// 1. Get aligned_ptr by result_num.
// 2. Deconstruct Result.
for (size_t idx = 0; idx < num_results_; ++idx) {
detail::OpResultImpl *impl = result(idx).impl();
IR_ENFORCE(impl->use_empty(), "operation destroyed but still has uses.");
if (detail::OpOutlineResultImpl::classof(*impl)) {
static_cast<detail::OpOutlineResultImpl *>(impl)->~OpOutlineResultImpl();
} else {
static_cast<detail::OpInlineResultImpl *>(impl)->~OpInlineResultImpl();
}
}

// 3. Deconstruct Operation.
this->~Operation();

// 4. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
operand(idx).impl()->~OpOperandImpl();
}
// 5. Free memory.
uint32_t max_inline_result_num =
detail::OpResultImpl::GetMaxInlineResultIndex() + 1;
size_t result_mem_size =
Expand All @@ -122,46 +140,11 @@ void Operation::Destroy() {
(num_results_ - max_inline_result_num) +
sizeof(detail::OpInlineResultImpl) * max_inline_result_num
: sizeof(detail::OpInlineResultImpl) * num_results_;
char *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;
// 2.1. Deconstruct OpResult.
char *base_ptr = aligned_ptr;
for (size_t idx = num_results_; idx > 0; idx--) {
// release the uses of this result
detail::OpOperandImpl *first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
while (first_use != nullptr) {
first_use->RemoveFromUdChain();
first_use =
reinterpret_cast<detail::OpResultImpl *>(base_ptr)->first_use();
}
// destory the result
if (idx > max_inline_result_num) {
reinterpret_cast<detail::OpOutlineResultImpl *>(base_ptr)
->~OpOutlineResultImpl();
base_ptr += sizeof(detail::OpOutlineResultImpl);
} else {
reinterpret_cast<detail::OpInlineResultImpl *>(base_ptr)
->~OpInlineResultImpl();
base_ptr += sizeof(detail::OpInlineResultImpl);
}
}
// 2.2. Deconstruct Operation.
if (reinterpret_cast<uintptr_t>(base_ptr) !=
reinterpret_cast<uintptr_t>(this)) {
IR_THROW("Operation address error");
}
reinterpret_cast<Operation *>(base_ptr)->~Operation();
base_ptr += sizeof(Operation);
// 2.3. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
reinterpret_cast<detail::OpOperandImpl *>(base_ptr)->~OpOperandImpl();
base_ptr += sizeof(detail::OpOperandImpl);
}
// 3. Free memory.
VLOG(4) << "Destroy an Operation: {ptr = "
<< reinterpret_cast<void *>(aligned_ptr)
void *aligned_ptr = reinterpret_cast<char *>(this) - result_mem_size;

VLOG(4) << "Destroy an Operation: {ptr = " << aligned_ptr
<< ", size = " << result_mem_size << "}";
aligned_free(reinterpret_cast<void *>(aligned_ptr));
aligned_free(aligned_ptr);
}

IrContext *Operation::ir_context() const { return info_.ir_context(); }
Expand Down Expand Up @@ -231,7 +214,7 @@ Program *Operation::GetParentProgram() {
return module_op ? module_op.program() : nullptr;
}

Region &Operation::GetRegion(unsigned index) {
Region &Operation::region(unsigned index) {
assert(index < num_regions_ && "invalid region index");
return regions_[index];
}
Expand Down
8 changes: 5 additions & 3 deletions paddle/ir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class IR_API alignas(8) Operation final {

OpOperand operand(uint32_t index) const;

/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);

void Print(std::ostream &os);

const AttributeMap &attributes() const { return attributes_; }
Expand Down Expand Up @@ -95,11 +98,10 @@ class IR_API alignas(8) Operation final {

Program *GetParentProgram();

/// Returns the region held by this operation at position 'index'.
Region &GetRegion(unsigned index);

operator Block::iterator() { return position_; }

operator Block::const_iterator() const { return position_; }

private:
Operation(const AttributeMap &attribute,
ir::OpInfo op_info,
Expand Down
6 changes: 6 additions & 0 deletions paddle/ir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ struct OperationArgument {
info(info),
regions(std::move(regions)) {}

/// Add Operand.
void AddOperand(OpResult operand) { inputs.emplace_back(operand); }

template <class InputIt>
void AddOperands(InputIt first, InputIt last);

/// Add Output.
void AddOutput(Type type) { output_types.emplace_back(type); }

template <class InputIt>
void AddTypes(InputIt first, InputIt last);

Expand Down
7 changes: 7 additions & 0 deletions paddle/ir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/ir/core/region.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"

namespace ir {
Region::~Region() { clear(); }
Expand All @@ -29,6 +30,12 @@ Region::iterator Region::insert(const_iterator position, Block *block) {
block->SetParent(this, iter);
return iter;
}

Region::iterator Region::erase(const_iterator position) {
IR_ENFORCE((*position)->GetParent() == this, "iterator not own this region.");
delete *position;
return blocks_.erase(position);
}
void Region::TakeBody(Region &&other) {
clear();
blocks_.swap(other.blocks_);
Expand Down
1 change: 1 addition & 0 deletions paddle/ir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class IR_API Region {
void emplace_back();
void push_front(Block *block);
iterator insert(const_iterator position, Block *block);
iterator erase(const_iterator position);
void clear();

void TakeBody(Region &&other);
Expand Down
31 changes: 21 additions & 10 deletions paddle/ir/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,22 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
}
OpOperand::operator bool() const { return impl_ && impl_->source(); }

OpOperand OpOperand::next_use() const { return impl_->next_use(); }
OpOperand OpOperand::next_use() const { return impl()->next_use(); }

Value OpOperand::source() const { return impl_->source(); }
Value OpOperand::source() const { return impl()->source(); }

void OpOperand::set_source(Value value) {
IR_ENFORCE(impl_, "Can't set source for a null value.");
impl_->set_source(value);
}
Type OpOperand::type() const { return source().type(); }

void OpOperand::set_source(Value value) { impl()->set_source(value); }

Operation *OpOperand::owner() const { return impl()->owner(); }

Operation *OpOperand::owner() const { return impl_->owner(); }
void OpOperand::RemoveFromUdChain() { return impl()->RemoveFromUdChain(); }

detail::OpOperandImpl *OpOperand::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while operand is null.");
return impl_;
}
// Value
Value::Value(const detail::ValueImpl *impl)
: impl_(const_cast<detail::ValueImpl *>(impl)) {}
Expand Down Expand Up @@ -84,13 +89,18 @@ void Value::ReplaceUsesWithIf(
Value new_value,
const std::function<bool(OpOperand)> &should_replace) const {
for (auto it = begin(); it != end();) {
auto cur = it++;
if (should_replace(*cur)) {
cur->set_source(new_value);
if (should_replace(*it)) {
(it++)->set_source(new_value);
}
}
}

void Value::ReplaceAllUsesWith(Value new_value) const {
for (auto it = begin(); it != end();) {
(it++)->set_source(new_value);
}
}

detail::ValueImpl *Value::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return impl_;
Expand All @@ -106,6 +116,7 @@ Operation *OpResult::owner() const { return impl()->owner(); }
uint32_t OpResult::GetResultIndex() const { return impl()->GetResultIndex(); }

detail::OpResultImpl *OpResult::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while value is null.");
return reinterpret_cast<detail::OpResultImpl *>(impl_);
}

Expand Down
Loading