Skip to content

Commit

Permalink
[PIR] remove pir::Value:;GetDefinitionOp interface
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Sep 18, 2023
1 parent d481b4f commit 673b5b7
Show file tree
Hide file tree
Showing 19 changed files with 74 additions and 90 deletions.
8 changes: 4 additions & 4 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1309,11 +1309,11 @@ void NewIRInterpreter::SolvePersisableVarNames() {
::pir::Value value = kv.first;
const std::string& var_name = kv.second;
::pir::OpResult result = value.dyn_cast<::pir::OpResult>();
auto* defining_op = value.GetDefiningOp();
auto* defining_op = result.owner();
if (defining_op->HasAttribute(kAttrIsPersisable)) {
auto is_persisables = defining_op->attribute(kAttrIsPersisable)
.dyn_cast<::pir::ArrayAttribute>()
.AsVector();
auto is_persisables =
defining_op->attribute<::pir::ArrayAttribute>(kAttrIsPersisable)
.AsVector();
if (is_persisables[result.index()]
.dyn_cast<::pir::BoolAttribute>()
.data()) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """
pir::CombineOp combine_op_obj =
op_obj.{input_name}().GetDefiningOp()->dyn_cast<pir::CombineOp>();
op_obj.{input_name}().dyn_cast<pir::OpResult>().owner()->dyn_cast<pir::CombineOp>();
std::vector<Tensor> {input_name};
for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{
{input_name}.emplace_back(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ paddle::framework::Variable* CreateVar(
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
Operation* def_op = value.GetDefiningOp();
Operation* def_op = value.dyn_cast<OpResult>().owner();
bool is_persisable = false;
if (def_op->isa<::pir::SetParameterOp>()) {
is_persisable = true;
Expand Down
15 changes: 8 additions & 7 deletions paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ static bool CanBeDeleted(pir::Value value) {
!value.type().isa<paddle::dialect::AllocatedSelectedRowsType>()) {
return false;
}
if (value.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) {
return !(value.GetDefiningOp()
->attribute(kAttrIsPersisable)
.dyn_cast<pir::ArrayAttribute>()
.AsVector()[value.dyn_cast<pir::OpResult>().index()]
.dyn_cast<pir::BoolAttribute>()
.data());
if (auto op_result = value.dyn_cast<pir::OpResult>()) {
auto def_op = op_result.owner();
if (def_op->HasAttribute(kAttrIsPersisable)) {
return !(def_op->attribute<pir::ArrayAttribute>(kAttrIsPersisable)
.AsVector()[op_result.index()]
.dyn_cast<pir::BoolAttribute>()
.data());
}
}
return true;
}
Expand Down
48 changes: 25 additions & 23 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in,
pir::Operation* op =
pir::Operation::Create({in}, op_attribute, {out_type}, op_info);

if (in.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) {
if (in.owner()->HasAttribute(kAttrIsPersisable)) {
op->set_attribute(kAttrIsPersisable,
in.GetDefiningOp()->attribute(kAttrIsPersisable));
in.owner()->attribute(kAttrIsPersisable));
}
block->push_back(op);

Expand Down Expand Up @@ -527,11 +527,10 @@ phi::KernelKey GetKernelKey(
if (op->isa<paddle::dialect::UniformOp>()) {
// try to process uniform, use shape to determin backend
// TODO(phlrain): shuold support other initilize op
auto define_op = op->operand_source(0).GetDefiningOp();
auto define_op =
op->operand_source(0).dyn_cast<pir::OpResult>().owner();
if (define_op->isa<paddle::dialect::FullIntArrayOp>()) {
auto shape = define_op->attributes()
.at("value")
.dyn_cast<dialect::IntArrayAttribute>()
auto shape = define_op->attribute<dialect::IntArrayAttribute>("value")
.data()
.GetData();

Expand Down Expand Up @@ -577,13 +576,12 @@ phi::KernelKey GetKernelKey(
// uses data op outout as inputs. So, we need set kernel backend
// manually.
if (op->operand_source(i)
.GetDefiningOp()
.dyn_cast<pir::OpResult>()
.owner()
->isa<paddle::dialect::DataOp>()) {
auto data_op = op->operand_source(i).GetDefiningOp();
auto data_place = data_op->attributes()
.at("place")
.dyn_cast<dialect::PlaceAttribute>()
.data();
auto data_op = op->operand_source(i).dyn_cast<pir::OpResult>().owner();
auto data_place =
data_op->attribute<dialect::PlaceAttribute>("place").data();

auto data_op_backend = paddle::experimental::ParseBackend(data_place);
if (data_op_backend == phi::Backend::UNDEFINED) {
Expand All @@ -592,17 +590,21 @@ phi::KernelKey GetKernelKey(
kernel_key_parser.key_set.backend_set =
kernel_key_parser.key_set.backend_set |
paddle::experimental::BackendSet(data_op_backend);
} else if (op->operand_source(i).GetDefiningOp()->name() ==
"builtin.combine") {
auto combine_op = op->operand_source(i).GetDefiningOp();
} else if (op->operand_source(i)
.dyn_cast<pir::OpResult>()
.owner()
->isa<pir::CombineOp>()) {
auto combine_op =
op->operand_source(i).dyn_cast<pir::OpResult>().owner();
for (size_t j = 0; j < combine_op->num_operands(); ++j) {
if (combine_op->operand_source(j).GetDefiningOp()->name() ==
"pd_op.data") {
auto data_op = combine_op->operand_source(j).GetDefiningOp();
auto data_place = data_op->attributes()
.at("place")
.dyn_cast<dialect::PlaceAttribute>()
.data();
if (combine_op->operand_source(j)
.dyn_cast<pir::OpResult>()
.owner()
->isa<DataOp>()) {
auto data_op =
combine_op->operand_source(j).dyn_cast<pir::OpResult>().owner();
auto data_place =
data_op->attribute<PlaceAttribute>("place").data();

auto data_op_backend =
paddle::experimental::ParseBackend(data_place);
Expand Down Expand Up @@ -981,7 +983,7 @@ std::vector<pir::Value> BuildOpInputList(
} else if (new_in_type.isa<pir::VectorType>()) {
// [ todo need update here, support combine data transfomer]
// deal with pre combine op
auto pre_define_op = cur_in.GetDefiningOp();
auto pre_define_op = cur_in.dyn_cast<pir::OpResult>().owner();

if (pre_define_op->isa<::pir::CombineOp>()) {
std::vector<pir::Value> inner_inputs;
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/transform_general_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace pir {
std::pair<std::string, pir::Parameter*> GetParameterFromValue(
pir::Value value) {
pir::GetParameterOp op =
value.GetDefiningOp()->dyn_cast<pir::GetParameterOp>();
value.dyn_cast<OpResult>().owner()->dyn_cast<pir::GetParameterOp>();
PADDLE_ENFORCE_NOT_NULL(
op,
phi::errors::InvalidArgument(
Expand Down Expand Up @@ -66,7 +66,7 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
return op->operand_source(index).GetDefiningOp();
return op->operand_source(index).dyn_cast<OpResult>().owner();
}

Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ return vjp_res;
{% macro get_mutable_attribute(attrs, api_name) %}
{% for i in attrs %}
{%- if i is mutable_attribute -%}
auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->value().dyn_cast<pir::OpResult>().GetDefiningOp();
auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->value().dyn_cast<pir::OpResult>().owner();
{% if i.typename is scalar %}
if({{i.name}}_define_op->name() != "pd_op.full") {
PADDLE_THROW(platform::errors::Unimplemented(
Expand Down
26 changes: 13 additions & 13 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,12 @@ void BindValue(py::module *m) {
)DOC");
value
.def("get_defining_op",
&Value::GetDefiningOp,
return_value_policy::reference)
.def(
"get_defining_op",
[](const Value &self) {
return self.dyn_cast<pir::OpResult>().owner();
},
return_value_policy::reference)
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &Value::use_empty)
Expand Down Expand Up @@ -444,22 +447,19 @@ void BindOpResult(py::module *m) {
})
.def("__hash__",
[](OpResult &self) { return std::hash<pir::Value>{}(self); })
.def("get_defining_op",
&OpResult::GetDefiningOp,
return_value_policy::reference)
.def("get_defining_op", &OpResult::owner, return_value_policy::reference)
.def_property_readonly(
"block",
[](OpResult &self) { return self.GetDefiningOp()->GetParent(); },
[](OpResult &self) { return self.owner()->GetParent(); },
return_value_policy::reference)
.def_property_readonly(
"name",
[](OpResult &self) {
if (self.GetDefiningOp()->isa<::pir::GetParameterOp>()) {
auto param_name = self.GetDefiningOp()
->attributes()
.at("parameter_name")
.dyn_cast<pir::StrAttribute>()
.AsString();
if (self.owner()->isa<::pir::GetParameterOp>()) {
auto param_name =
self.owner()
->attribute<pir::StrAttribute>("parameter_name")
.AsString();
return param_name;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
Expand Down
6 changes: 3 additions & 3 deletions paddle/pir/core/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class IR_API Attribute {

bool operator!() const { return storage_ == nullptr; }

operator const void *() const { return storage_; }

///
/// \brief Some Attribute attribute acquisition interfaces.
///
Expand Down Expand Up @@ -85,8 +87,6 @@ class IR_API Attribute {
return pir::dyn_cast<U>(*this);
}

friend struct std::hash<Attribute>;

protected:
const Storage *storage_{nullptr};
};
Expand All @@ -98,7 +98,7 @@ namespace std {
template <>
struct hash<pir::Attribute> {
std::size_t operator()(const pir::Attribute &obj) const {
return std::hash<const pir::Attribute::Storage *>()(obj.storage_);
return std::hash<const void *>()(obj);
}
};
} // namespace std
2 changes: 1 addition & 1 deletion paddle/pir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
for (uint32_t i = 0; i < defining_op->num_operands(); ++i) {
auto value = defining_op->operand_source(i);
if (!value) continue;
auto *oprand_defining_op = value.GetDefiningOp();
auto *oprand_defining_op = value.dyn_cast<OpResult>().owner();
if (oprand_defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = oprand_defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/op_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ class IR_API OpInfo {
template <typename InterfaceT>
typename InterfaceT::Concept *GetInterfaceImpl() const;

operator const void *() const { return impl_; }
void *AsOpaquePointer() const { return impl_; }
static OpInfo RecoverFromOpaquePointer(void *pointer) {
return OpInfo(static_cast<OpInfoImpl *>(pointer));
}

friend class OpInfoImpl;
friend struct std::hash<OpInfo>;

private:
explicit OpInfo(OpInfoImpl *impl) : impl_(impl) {}
Expand Down Expand Up @@ -105,7 +105,7 @@ namespace std {
template <>
struct hash<pir::OpInfo> {
std::size_t operator()(const pir::OpInfo &obj) const {
return std::hash<const pir::OpInfoImpl *>()(obj.impl_);
return std::hash<const void *>()(obj);
}
};
} // namespace std
15 changes: 3 additions & 12 deletions paddle/pir/core/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class IR_API Type {
Type() = default;

Type(const Storage *storage) // NOLINT
: storage_(const_cast<Storage *>(storage)) {}
: storage_(storage) {}

Type(const Type &other) = default;

Expand All @@ -73,11 +73,7 @@ class IR_API Type {
///
/// \brief Support PointerLikeTypeTraits.
///
///
const void *AsOpaquePointer() const {
return static_cast<const void *>(storage_);
}

operator const void *() const { return storage_; }
static Type RecoverFromOpaquePointer(const void *pointer) {
return Type(reinterpret_cast<Storage *>(const_cast<void *>(pointer)));
}
Expand Down Expand Up @@ -120,11 +116,6 @@ class IR_API Type {

static Type Parse(std::istream &is, IrContext *ctx);

///
/// \brief Enable hashing Type.
///
friend struct std::hash<Type>;

template <typename U>
U cast() const {
return pir::cast<U>(*this);
Expand Down Expand Up @@ -189,7 +180,7 @@ namespace std {
template <>
struct hash<pir::Type> {
std::size_t operator()(const pir::Type &obj) const {
return std::hash<const pir::Type::Storage *>()(obj.storage_);
return std::hash<const void *>()(obj);
}
};
} // namespace std
8 changes: 2 additions & 6 deletions paddle/pir/core/type_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class TypeId {
///
/// \brief Support PointerLikeTypeTraits.
///
operator const void *() const { return storage_; }
void *AsOpaquePointer() const { return storage_; }
static TypeId RecoverFromOpaquePointer(void *pointer) {
return TypeId(static_cast<Storage *>(pointer));
Expand All @@ -71,11 +72,6 @@ class TypeId {
return storage_ < other.storage_;
}

///
/// \brief Enable hashing TypeId instances.
///
friend struct std::hash<TypeId>;

private:
///
/// \brief Construct a TypeId and initialize storage.
Expand Down Expand Up @@ -150,7 +146,7 @@ namespace std {
template <>
struct hash<pir::TypeId> {
std::size_t operator()(const pir::TypeId &obj) const {
return std::hash<const pir::TypeId::Storage *>()(obj.storage_);
return std::hash<const void *>()(obj);
}
};
} // namespace std
5 changes: 0 additions & 5 deletions paddle/pir/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ void Value::set_type(pir::Type type) {
impl_->set_type(type);
}

Operation *Value::GetDefiningOp() const {
if (auto result = dyn_cast<OpResult>()) return result.owner();
return nullptr;
}

std::string Value::PrintUdChain() {
CHECK_VALUE_NULL_IMPL(PrintUdChain);
return impl()->PrintUdChain();
Expand Down
2 changes: 0 additions & 2 deletions paddle/pir/core/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ class IR_API Value {

void set_type(Type type);

Operation *GetDefiningOp() const;

std::string PrintUdChain();

///
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/dialect/shape/utils/shape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ bool SymbolicDimMgr::LoadShapeConstraintGraph() {
auto build_sym_product = [&](std::vector<Value> range,
SymbolicDimProduct& product) {
for (Value v : range) {
auto definingOp = v.GetDefiningOp();
auto definingOp = v.dyn_cast<OpResult>().owner();
if (auto constOp = definingOp->dyn_cast<ConstantOp>()) {
product.factor *= constOp.value().dyn_cast<Int32Attribute>().data();
continue;
Expand Down
3 changes: 2 additions & 1 deletion paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ class GreedyPatternRewriteDriver : public pir::PatternRewriter {
// that single use values often have more canonicalization opportunities.
if (!operand || (!operand.use_empty() && !operand.HasOneUse())) return;

if (auto* def_op = operand.GetDefiningOp()) AddToWorklist(def_op);
if (auto* def_op = operand.dyn_cast<pir::OpResult>().owner())
AddToWorklist(def_op);
}

void AddOperandsToWorklist(const std::vector<pir::Value> operands) {
Expand Down
Loading

0 comments on commit 673b5b7

Please sign in to comment.