Skip to content

Commit

Permalink
[PIR] add blocks interface for pir::Operation.
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Nov 22, 2023
1 parent 98706c1 commit 5cedff4
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 40 deletions.
62 changes: 27 additions & 35 deletions paddle/fluid/pybind/control_flow_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,39 +49,7 @@ using pir::YieldOp;
using pybind11::return_value_policy;

namespace {
class PyIfOp : public IfOp {
public:
explicit PyIfOp(IfOp if_op);
void UpdateOutput();
};

PyIfOp::PyIfOp(IfOp if_op) : IfOp(if_op) {
PADDLE_ENFORCE_NOT_NULL(
if_op,
paddle::platform::errors::InvalidArgument(
"The if_op used to construct PyIfOp can't be nullptr"));
}

void PyIfOp::UpdateOutput() {
PADDLE_ENFORCE_NOT_NULL(
*this,
paddle::platform::errors::InvalidArgument(
"The if_op in PyIfOp used to update output can't be nullptr"));
auto block = parent();
PADDLE_ENFORCE_NOT_NULL(block,
paddle::platform::errors::InvalidArgument(
"The parent block of if_op which used to update "
"output can't be nullptr"));
Block::Iterator iter = **this;
Builder builder(ir_context(), false);
auto new_if_op = builder.Build<IfOp>(
cond(), true_region().TakeBack(), false_region().TakeBack());
block->Assign(iter, new_if_op);
IfOp::operator=(new_if_op);
VerifyRegion();
}

PyIfOp BuildPyIfOp(Value cond) {
paddle::pybind::PyIfOp BuildPyIfOp(Value cond) {
return PyIfOp(ApiBuilder::Instance().GetBuilder()->Build<IfOp>(
cond, std::vector<Type>{}));
}
Expand Down Expand Up @@ -185,11 +153,35 @@ void BuildPipeForBlock(Block* block) {

namespace paddle {
namespace pybind {
PyIfOp::PyIfOp(IfOp if_op) : IfOp(if_op) {
PADDLE_ENFORCE_NOT_NULL(
if_op,
paddle::platform::errors::InvalidArgument(
"The if_op used to construct PyIfOp can't be nullptr"));
}

void PyIfOp::UpdateOutput() {
PADDLE_ENFORCE_NOT_NULL(
*this,
paddle::platform::errors::InvalidArgument(
"The if_op in PyIfOp used to update output can't be nullptr"));
auto block = parent();
PADDLE_ENFORCE_NOT_NULL(block,
paddle::platform::errors::InvalidArgument(
"The parent block of if_op which used to update "
"output can't be nullptr"));
Block::Iterator iter = **this;
Builder builder(ir_context(), false);
auto new_if_op = builder.Build<IfOp>(
cond(), true_region().TakeBack(), false_region().TakeBack());
block->Assign(iter, new_if_op);
IfOp::operator=(new_if_op);
VerifyRegion();
}

void BindControlFlowApi(py::module* m) {
m->def("get_used_external_value", GetUsedExternalValue);
m->def("build_pipe_for_block", BuildPipeForBlock);
m->def("cvt_as_if_op",
[](Operation& op) { return PyIfOp(op.dyn_cast<IfOp>()); });
BindIfOp(m);
}
} // namespace pybind
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/pybind/control_flow_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
#pragma once

#include <pybind11/pybind11.h>
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"

namespace paddle {
namespace pybind {
class PyIfOp : public dialect::IfOp {
public:
explicit PyIfOp(dialect::IfOp if_op);
void UpdateOutput();
};

void BindControlFlowApi(pybind11::module *m);
} // namespace pybind
} // namespace paddle
7 changes: 6 additions & 1 deletion paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ namespace py = pybind11;
using paddle::dialect::ApiBuilder;
using paddle::dialect::DenseTensorArrayType;
using paddle::dialect::DenseTensorType;
using paddle::dialect::IfOp;
using paddle::dialect::SelectedRowsType;

using pir::Attribute;
using pir::Block;
using pir::Operation;
Expand Down Expand Up @@ -374,6 +376,7 @@ void BindOperation(py::module *m) {
.def("operand_source", &Operation::operand_source)
.def("operands", &Operation::operands)
.def("results", &Operation::results)
.def("blocks", [](Operation &self) { return self.blocks(); })
.def("attrs",
[](Operation &self) -> py::dict {
py::dict attrs_dict;
Expand Down Expand Up @@ -449,7 +452,9 @@ void BindOperation(py::module *m) {
.def("replace_all_uses_with",
[](Operation &self, const std::vector<OpResult> &op_results) {
self.ReplaceAllUsesWith(op_results);
});
})
.def("as_if_op",
[](Operation &self) { return PyIfOp(self.dyn_cast<IfOp>()); });
}

py::str Value2String(const Value &self) {
Expand Down
108 changes: 108 additions & 0 deletions paddle/pir/core/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once
#include <iterator>
#include <list>
#include "paddle/pir/core/macros.h"
namespace pir {

class Operation;
Expand Down Expand Up @@ -187,4 +188,111 @@ class PointerListConstIterator {
operator ElementType*() const { return *iterator_; }
};

///
/// \brief The Iterator used to flatten two-level containers into one level.
///
template <typename ContainerT>
class DoubleLevelContainer {
public:
class Iterator;
Iterator begin();
Iterator end();

protected:
// only support constructed by ConstainerT;
DoubleLevelContainer() = default;
DISABLE_COPY_AND_ASSIGN(DoubleLevelContainer);
const ContainerT& container() const {
return *static_cast<const ContainerT*>(this);
}
ContainerT& container() { return *static_cast<ContainerT*>(this); }
};
template <typename ContainerT>
class DoubleLevelContainer<ContainerT>::Iterator {
public:
using OuterIterator = typename ContainerT::Iterator;
using InnerIterator = typename ContainerT::Element::Iterator;
using Element = typename ContainerT::Element::Element;

Iterator() = default;
Iterator(const OuterIterator& outer_iter,
const OuterIterator& outer_end,
const InnerIterator& inner_iter)
: outer_iter_(outer_iter),
outer_end_(outer_end),
inner_iter_(inner_iter) {}

Element& operator*() const noexcept { return *inner_iter_; }

Element* operator->() const noexcept { return &this->operator*(); }

Iterator& operator++() noexcept {
++inner_iter_;
while (inner_iter_ == outer_iter_->end()) {
++outer_iter_;
if (outer_iter_ == outer_end_) break;
inner_iter_ = outer_iter_->begin();
}
return *this;
}
Iterator operator++(int) noexcept {
Iterator __tmp = *this;
++*this;
return __tmp;
}

Iterator& operator--() noexcept {
if (outer_iter_ == outer_end_) {
outer_iter_--;
inner_iter_ = outer_iter_->end();
}
while (inner_iter_ == outer_iter_->begin()) {
--outer_iter_;
inner_iter_ = outer_iter_->end();
}
--inner_iter_;
return *this;
}

Iterator operator--(int) noexcept {
Iterator __tmp = *this;
--*this;
return __tmp;
}

bool operator==(const Iterator& __x) const noexcept {
return outer_iter_ == __x.outer_iter_ &&
(outer_iter_ == outer_end_ || inner_iter_ == __x.inner_iter_);
}

bool operator!=(const Iterator& __x) const noexcept {
return !this->operator==(__x);
}

private:
OuterIterator outer_iter_, outer_end_;
InnerIterator inner_iter_;
};
template <typename ContainerT>
typename DoubleLevelContainer<ContainerT>::Iterator
DoubleLevelContainer<ContainerT>::begin() {
auto outer_iter = container().begin();
typename Iterator::InnerIterator inner_iter;
while (outer_iter != container().end()) {
if (outer_iter->empty()) {
++outer_iter;
} else {
inner_iter = outer_iter->begin();
break;
}
}
return Iterator(outer_iter, container().end(), inner_iter);
}

template <typename ContainerT>
typename DoubleLevelContainer<ContainerT>::Iterator
DoubleLevelContainer<ContainerT>::end() {
return Iterator(
container().end(), container().end(), typename Iterator::InnerIterator());
}
} // namespace pir
8 changes: 7 additions & 1 deletion paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <vector>
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/iterator.h"
#include "paddle/pir/core/macros.h"
#include "paddle/pir/core/op_info.h"
#include "paddle/pir/core/operation_utils.h"
Expand All @@ -34,7 +35,8 @@ class OpResultImpl;
class OpOperendImpl;
} // namespace detail

class IR_API alignas(8) Operation final {
class IR_API alignas(8) Operation final
: public DoubleLevelContainer<Operation> {
public:
///
/// \brief Malloc memory and construct objects in the following order:
Expand Down Expand Up @@ -109,6 +111,7 @@ class IR_API alignas(8) Operation final {
///
/// \brief region related public interfaces
///
using Element = Region;
using Iterator = Region *;
using ConstIterator = const Region *;
uint32_t num_regions() const { return num_regions_; }
Expand All @@ -119,6 +122,9 @@ class IR_API alignas(8) Operation final {
Iterator begin() { return regions_; }
Iterator end() { return regions_ + num_regions_; }

/// \brief block related public interfaces
DoubleLevelContainer<Operation> &blocks() { return *this; }

///
/// \brief parent related public interfaces
///
Expand Down
1 change: 1 addition & 0 deletions paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Program;

class IR_API Region {
public:
using Element = Block;
using Iterator = PointerListIterator<Block>;
using ConstIterator = PointerListConstIterator<Block>;
using ReverseIterator = std::reverse_iterator<Iterator>;
Expand Down
10 changes: 9 additions & 1 deletion test/cpp/pir/control_flow_dialect/if_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,22 @@ TEST(if_op_test, build_by_block) {

builder.SetInsertionPointToEnd(block);

builder.Build<paddle::dialect::IfOp>(
auto if_op = builder.Build<paddle::dialect::IfOp>(
full_op.out(), std::move(true_block), std::move(false_block));

EXPECT_FALSE(true_block);
EXPECT_FALSE(false_block);
EXPECT_EQ(full_op_2->GetParentProgram(), &program);

LOG(INFO) << program;

std::vector<pir::Block*> vec;
for (auto& block : if_op->blocks()) {
vec.push_back(&block);
}
EXPECT_EQ(vec.size(), 2u);
EXPECT_EQ(vec[0], if_op.true_block());
EXPECT_EQ(vec[1], if_op.false_block());
}

TEST(if_op_test, network_with_backward) {
Expand Down
3 changes: 1 addition & 2 deletions test/ir/pir/test_if_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import paddle
from paddle.base.libpaddle.pir import (
build_pipe_for_block,
cvt_as_if_op,
get_used_external_value,
)

Expand Down Expand Up @@ -60,7 +59,7 @@ def test_if_with_multiple_output(self):
out = paddle.static.nn.cond(pred, true_func, false_func)
self.assertEqual(out[0].get_defining_op().name(), "pd_op.if")
self.assertEqual(len(out), 2)
if_op = cvt_as_if_op(out[0].get_defining_op())
if_op = out[0].get_defining_op().as_if_op()
true_block = if_op.true_block()
self.assertEqual(len(true_block), 3)
build_pipe_for_block(true_block)
Expand Down

0 comments on commit 5cedff4

Please sign in to comment.