Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR]add block arguement. #57249

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
ctx_, defining_op_result, parameter_name_mappings_[var_name]);

pir::Block* block = program_->block();
pir::Block::iterator insert_pos = std::find(
pir::Block::Iterator insert_pos = std::find(
block->begin(), block->end(), defining_op_result.owner());

IR_ENFORCE(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class InplacePass : public pir::Pass {
.at("op_name")
.dyn_cast<pir::StrAttribute>()
.AsString();
pir::Block::iterator insert_pos =
pir::Block::Iterator insert_pos =
std::find(block->begin(), block->end(), kv.first);
IR_ENFORCE(insert_pos != block->end(),
"Operator %s not found in block.",
Expand Down
4 changes: 1 addition & 3 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,7 @@ void BindOpResult(py::module *m) {
return paddle::dialect::greater_equal(self, other);
})
.def("__hash__",
[](OpResult &self) {
return std::hash<pir::Value>{}(self.dyn_cast<pir::Value>());
})
[](OpResult &self) { return std::hash<pir::Value>{}(self); })
.def("get_defining_op",
&OpResult::GetDefiningOp,
return_value_policy::reference)
Expand Down
21 changes: 17 additions & 4 deletions paddle/pir/core/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@

namespace pir {
Block::~Block() {
assert(use_empty() && "block destroyed still has uses.");
if (!use_empty()) {
LOG(FATAL) << "Destoryed a block that is still in use.";
}
clear();
ClearArguments();
}
void Block::push_back(Operation *op) { insert(ops_.end(), op); }

Expand All @@ -33,13 +36,13 @@ Operation *Block::GetParentOp() const {
return parent_ ? parent_->GetParent() : nullptr;
}

Block::iterator Block::insert(const_iterator iterator, Operation *op) {
Block::iterator iter = ops_.insert(iterator, op);
Block::Iterator Block::insert(ConstIterator iterator, Operation *op) {
Block::Iterator iter = ops_.insert(iterator, op);
op->SetParent(this, iter);
return iter;
}

Block::iterator Block::erase(const_iterator position) {
Block::Iterator Block::erase(ConstIterator position) {
IR_ENFORCE((*position)->GetParent() == this, "iterator not own this block.");
(*position)->Destroy();
return ops_.erase(position);
Expand Down Expand Up @@ -75,6 +78,16 @@ void Block::ResetOpListOrder(const OpListType &new_op_list) {
}
}

void Block::ClearArguments() {
for (auto &argument : arguments_) {
argument.Destroy();
}
arguments_.clear();
}
void Block::AddArgument(Type type) {
arguments_.emplace_back(BlockArgument::Create(type, this, arguments_.size()));
}

bool Block::TopoOrderCheck(const OpListType &op_list) {
std::unordered_set<Value> visited_values;
for (const Operation *op : op_list) {
Expand Down
59 changes: 46 additions & 13 deletions paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstddef>
#include <list>

#include "paddle/pir/core/block_argument.h"
#include "paddle/pir/core/block_operand.h"
#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/core/region.h"
Expand All @@ -29,9 +30,9 @@ class IR_API Block {
using OpListType = std::list<Operation *>;

public:
using iterator = OpListType::iterator;
using reverse_iterator = OpListType::reverse_iterator;
using const_iterator = OpListType::const_iterator;
using Iterator = OpListType::iterator;
using ReverseIterator = OpListType::reverse_iterator;
using ConstIterator = OpListType::const_iterator;

Block() = default;
~Block();
Expand All @@ -42,19 +43,19 @@ class IR_API Block {
bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); }

const_iterator begin() const { return ops_.begin(); }
const_iterator end() const { return ops_.end(); }
iterator begin() { return ops_.begin(); }
iterator end() { return ops_.end(); }
reverse_iterator rbegin() { return ops_.rbegin(); }
reverse_iterator rend() { return ops_.rend(); }
ConstIterator begin() const { return ops_.begin(); }
ConstIterator end() const { return ops_.end(); }
Iterator begin() { return ops_.begin(); }
Iterator end() { return ops_.end(); }
ReverseIterator rbegin() { return ops_.rbegin(); }
ReverseIterator rend() { return ops_.rend(); }

Operation *back() const { return ops_.back(); }
Operation *front() const { return ops_.front(); }
void push_back(Operation *op);
void push_front(Operation *op);
iterator insert(const_iterator iterator, Operation *op);
iterator erase(const_iterator position);
Iterator insert(ConstIterator iterator, Operation *op);
Iterator erase(ConstIterator position);
void clear();
operator Region::iterator() { return position_; }

Expand All @@ -73,6 +74,29 @@ class IR_API Block {
// This is a unsafe funcion, please use it carefully.
void ResetOpListOrder(const OpListType &new_op_list);

///
/// \brief Block argument management
///
using BlockArgListType = std::vector<BlockArgument>;
using ArgsIterator = BlockArgListType::iterator;

ArgsIterator args_begin() { return arguments_.begin(); }
ArgsIterator args_end() { return arguments_.end(); }
bool args_empty() const { return arguments_.empty(); }
uint32_t args_size() const { return arguments_.size(); }
BlockArgument argument(uint32_t index) { return arguments_[index]; }
Type argument_type(uint32_t index) const { return arguments_[index].type(); }

void ClearArguments();
void AddArgument(Type type);
template <class TypeIter>
void AddArguments(TypeIter first, TypeIter last);

template <class TypeContainer>
void AddArguments(const TypeContainer &container) {
AddArguments(container.begin(), container.end());
}

private:
Block(Block &) = delete;
Block &operator=(const Block &) = delete;
Expand All @@ -84,9 +108,18 @@ class IR_API Block {
static bool TopoOrderCheck(const OpListType &op_list);

private:
Region *parent_; // not owned
OpListType ops_; // owned
Region::iterator position_;
BlockOperand first_use_;
OpListType ops_; // owned
BlockArgListType arguments_; // owned
Region *parent_; // not owned
};

template <class TypeIter>
void Block::AddArguments(TypeIter first, TypeIter last) {
while (first != last) {
AddArgument(*first++);
}
}

} // namespace pir
95 changes: 95 additions & 0 deletions paddle/pir/core/block_argument.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pir/core/block_argument.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/value_impl.h"

#define CHECK_NULL_IMPL(func_name) \
IR_ENFORCE(impl_, "impl_ is null when called BlockArgument:" #func_name)

#define IMPL_ static_cast<detail::BlockArgumentImpl *>(impl_)

namespace pir {

namespace detail {
///
/// \brief BlockArgumentImpl is the implementation of an block argument.
///
class BlockArgumentImpl : public ValueImpl {
public:
static bool classof(const ValueImpl &value) {
return value.kind() == BLOCK_ARGUMENT_INDEX;
}

private:
BlockArgumentImpl(Type type, Block *owner, uint32_t index)
: ValueImpl(type, BLOCK_ARGUMENT_INDEX), owner_(owner), index_(index) {}

~BlockArgumentImpl();
// access construction and owner
friend BlockArgument;
Block *owner_;
uint32_t index_;
};

BlockArgumentImpl::~BlockArgumentImpl() {
if (!use_empty()) {
LOG(FATAL) << "Destoryed a blockargument that is still in use.";
}
}

} // namespace detail

BlockArgument::BlockArgument(detail::BlockArgumentImpl *impl) : Value(impl) {}

bool BlockArgument::classof(Value value) {
return value && detail::BlockArgumentImpl::classof(*value.impl());
}

Block *BlockArgument::owner() const {
CHECK_NULL_IMPL(owner);
return IMPL_->owner_;
}

uint32_t BlockArgument::arg_index() const {
CHECK_NULL_IMPL(arg_index);
return IMPL_->index_;
}

BlockArgument BlockArgument::Create(Type type, Block *owner, uint32_t index) {
return new detail::BlockArgumentImpl(type, owner, index);
}
/// Destroy the argument.
void BlockArgument::Destroy() {
if (impl_) {
LOG(WARNING) << "Destroying a null block argument.";
} else {
delete IMPL_;
}
}

void BlockArgument::set_arg_index(uint32_t index) {
CHECK_NULL_IMPL(set_arg_number);
IMPL_->index_ = index;
}

BlockArgument BlockArgument::dyn_cast_from(Value value) {
if (classof(value)) {
return static_cast<detail::BlockArgumentImpl *>(value.impl());
} else {
return nullptr;
}
}

} // namespace pir
55 changes: 55 additions & 0 deletions paddle/pir/core/block_argument.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/core/value.h"
namespace pir {
class Block;

namespace detail {
class BlockArgumentImpl;
} // namespace detail

///
/// \brief BlockArgument class represents the value defined by a result of
/// operation. This class only provides interfaces, for specific implementation,
/// see Impl class.
///
class IR_API BlockArgument : public Value {
public:
BlockArgument() = default;
Block *owner() const;
uint32_t arg_index() const;

private:
/// constructor
BlockArgument(detail::BlockArgumentImpl *impl); // NOLINT

/// create a new argument with the given type and owner.
static BlockArgument Create(Type type, Block *owner, uint32_t index);
/// Destroy the argument.
void Destroy();
/// set the position in the block argument list.
void set_arg_index(uint32_t index);
// Access create annd destroy.
friend Block;

// Access classof annd dyn_cast_from.
friend Value;
static bool classof(Value value);
static BlockArgument dyn_cast_from(Value value);
};

} // namespace pir
12 changes: 6 additions & 6 deletions paddle/pir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class PointerAttribute;
///
class Builder {
public:
Builder(IrContext *context, Block *block, Block::iterator insert_point)
Builder(IrContext *context, Block *block, Block::Iterator insert_point)
: context_(context) {
SetInsertionPoint(block, insert_point);
}
Expand All @@ -57,10 +57,10 @@ class Builder {
: Builder(context, block, block->end()) {}

explicit Builder(IrContext *context)
: Builder(context, nullptr, Block::iterator{}) {}
: Builder(context, nullptr, Block::Iterator{}) {}

/// Set the insertion point to the specified location.
void SetInsertionPoint(Block *block, Block::iterator insert_point) {
void SetInsertionPoint(Block *block, Block::Iterator insert_point) {
// TODO(liuyuanle): check that insertPoint is in this rather than some other
// block.
this->block_ = block;
Expand All @@ -70,13 +70,13 @@ class Builder {
/// Set the insertion point to the specified operation, which will cause
/// subsequent insertions to go right before it.
void SetInsertionPoint(Operation *op) {
SetInsertionPoint(op->GetParent(), Block::iterator{*op});
SetInsertionPoint(op->GetParent(), Block::Iterator{*op});
}

/// Set the insertion point to the node after the specified operation, which
/// will cause subsequent insertions to go right after it.
void SetInsertionPointAfter(Operation *op) {
SetInsertionPoint(op->GetParent(), std::next(Block::iterator{*op}));
SetInsertionPoint(op->GetParent(), std::next(Block::Iterator{*op}));
}

/// Set the insertion point to the start of the specified block.
Expand Down Expand Up @@ -138,7 +138,7 @@ class Builder {
IrContext *context_;
Block *block_;
// The insertion point within the list that this builder is inserting before.
Block::iterator insert_point_;
Block::Iterator insert_point_;
};

} // namespace pir
Loading