
[PIR] standardize the use of value.GetDefiningOp(). #57461

Merged
merged 1 commit into from Sep 19, 2023
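This PR deletes `Value::GetDefiningOp()` and standardizes every call site on `pir::OpResult::owner()`, reached through an explicit `value.dyn_cast<pir::OpResult>()` (or called directly when an `OpResult` is already in hand). Along the way it adopts the templated `attribute<T>(name)` accessor in place of `attributes().at(name).dyn_cast<T>()` chains, and swaps the `friend struct std::hash<...>` declarations on `Attribute`, `OpInfo`, `Type`, and `TypeId` for a public `operator const void *()` conversion. One behavioral nuance worth noting: the deleted `Value::GetDefiningOp()` returned `nullptr` for values that are not operation results, so callers that relied on that must now guard themselves. A minimal sketch of such a guard, assuming only the APIs visible in this diff (`GetDefiningOpOrNull` is a hypothetical helper, not part of this PR):

```cpp
// Hypothetical helper reproducing the behavior of the deleted
// Value::GetDefiningOp() (see the paddle/pir/core/value.cc diff below).
pir::Operation *GetDefiningOpOrNull(pir::Value val) {
  // dyn_cast yields a null OpResult when `val` is not an operation
  // result (e.g. a block argument); guard before calling owner().
  if (auto result = val.dyn_cast<pir::OpResult>()) return result.owner();
  return nullptr;
}
```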
8 changes: 4 additions & 4 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -1309,11 +1309,11 @@ void NewIRInterpreter::SolvePersisableVarNames() {
::pir::Value value = kv.first;
const std::string& var_name = kv.second;
::pir::OpResult result = value.dyn_cast<::pir::OpResult>();
-    auto* defining_op = value.GetDefiningOp();
+    auto* defining_op = result.owner();
if (defining_op->HasAttribute(kAttrIsPersisable)) {
-      auto is_persisables = defining_op->attribute(kAttrIsPersisable)
-                                .dyn_cast<::pir::ArrayAttribute>()
-                                .AsVector();
+      auto is_persisables =
+          defining_op->attribute<::pir::ArrayAttribute>(kAttrIsPersisable)
+              .AsVector();
if (is_persisables[result.index()]
.dyn_cast<::pir::BoolAttribute>()
.data()) {
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
@@ -27,7 +27,7 @@

OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """
pir::CombineOp combine_op_obj =
-      op_obj.{input_name}().GetDefiningOp()->dyn_cast<pir::CombineOp>();
+      op_obj.{input_name}().dyn_cast<pir::OpResult>().owner()->dyn_cast<pir::CombineOp>();
std::vector<Tensor> {input_name};
for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{
{input_name}.emplace_back(
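For reference, instantiating the rewritten template above with a hypothetical input named `x` would now generate code along these lines (the `op_obj` accessor is illustrative):

```cpp
// Hypothetical expansion of OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE
// for input_name = "x".
pir::CombineOp combine_op_obj =
    op_obj.x().dyn_cast<pir::OpResult>().owner()->dyn_cast<pir::CombineOp>();
```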
2 changes: 1 addition & 1 deletion paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -125,7 +125,7 @@ paddle::framework::Variable* CreateVar(
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
-  Operation* def_op = value.GetDefiningOp();
+  Operation* def_op = value.dyn_cast<OpResult>().owner();
bool is_persisable = false;
if (def_op->isa<::pir::SetParameterOp>()) {
is_persisable = true;
15 changes: 8 additions & 7 deletions paddle/fluid/pir/transforms/inplace_pass.cc
@@ -37,13 +37,14 @@ static bool CanBeDeleted(pir::Value value) {
!value.type().isa<paddle::dialect::AllocatedSelectedRowsType>()) {
return false;
}
-  if (value.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) {
-    return !(value.GetDefiningOp()
-                 ->attribute(kAttrIsPersisable)
-                 .dyn_cast<pir::ArrayAttribute>()
-                 .AsVector()[value.dyn_cast<pir::OpResult>().index()]
-                 .dyn_cast<pir::BoolAttribute>()
-                 .data());
+  if (auto op_result = value.dyn_cast<pir::OpResult>()) {
+    auto def_op = op_result.owner();
+    if (def_op->HasAttribute(kAttrIsPersisable)) {
+      return !(def_op->attribute<pir::ArrayAttribute>(kAttrIsPersisable)
+                   .AsVector()[op_result.index()]
+                   .dyn_cast<pir::BoolAttribute>()
+                   .data());
+    }
   }
return true;
}
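The `inplace_pass.cc` hunk above also shows the attribute-access shorthand this PR adopts throughout. A minimal sketch of the before and after forms, borrowing the `parameter_name` lookup that appears verbatim in the `paddle/fluid/pybind/ir.cc` diff below (`op` is an illustrative `pir::Operation *`):

```cpp
// Before: fetch a generic Attribute from the map, then cast it.
auto name_long = op->attributes()
                     .at("parameter_name")
                     .dyn_cast<pir::StrAttribute>()
                     .AsString();
// After: the templated accessor folds the lookup and the cast into one call.
auto name_short =
    op->attribute<pir::StrAttribute>("parameter_name").AsString();
```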
48 changes: 25 additions & 23 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -248,9 +248,9 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in,
pir::Operation* op =
pir::Operation::Create({in}, op_attribute, {out_type}, op_info);

-  if (in.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) {
+  if (in.owner()->HasAttribute(kAttrIsPersisable)) {
op->set_attribute(kAttrIsPersisable,
-                      in.GetDefiningOp()->attribute(kAttrIsPersisable));
+                      in.owner()->attribute(kAttrIsPersisable));
}
block->push_back(op);

@@ -527,11 +527,10 @@ phi::KernelKey GetKernelKey(
if (op->isa<paddle::dialect::UniformOp>()) {
// try to process uniform, use shape to determin backend
// TODO(phlrain): shuold support other initilize op
-    auto define_op = op->operand_source(0).GetDefiningOp();
+    auto define_op =
+        op->operand_source(0).dyn_cast<pir::OpResult>().owner();
if (define_op->isa<paddle::dialect::FullIntArrayOp>()) {
-      auto shape = define_op->attributes()
-                       .at("value")
-                       .dyn_cast<dialect::IntArrayAttribute>()
+      auto shape = define_op->attribute<dialect::IntArrayAttribute>("value")
.data()
.GetData();

@@ -577,13 +576,12 @@ phi::KernelKey GetKernelKey(
// uses data op outout as inputs. So, we need set kernel backend
// manually.
if (op->operand_source(i)
-            .GetDefiningOp()
+            .dyn_cast<pir::OpResult>()
+            .owner()
->isa<paddle::dialect::DataOp>()) {
-      auto data_op = op->operand_source(i).GetDefiningOp();
-      auto data_place = data_op->attributes()
-                            .at("place")
-                            .dyn_cast<dialect::PlaceAttribute>()
-                            .data();
+      auto data_op = op->operand_source(i).dyn_cast<pir::OpResult>().owner();
+      auto data_place =
+          data_op->attribute<dialect::PlaceAttribute>("place").data();

auto data_op_backend = paddle::experimental::ParseBackend(data_place);
if (data_op_backend == phi::Backend::UNDEFINED) {
@@ -592,17 +590,21 @@
kernel_key_parser.key_set.backend_set =
kernel_key_parser.key_set.backend_set |
paddle::experimental::BackendSet(data_op_backend);
-    } else if (op->operand_source(i).GetDefiningOp()->name() ==
-               "builtin.combine") {
-      auto combine_op = op->operand_source(i).GetDefiningOp();
+    } else if (op->operand_source(i)
+                   .dyn_cast<pir::OpResult>()
+                   .owner()
+                   ->isa<pir::CombineOp>()) {
+      auto combine_op =
+          op->operand_source(i).dyn_cast<pir::OpResult>().owner();
for (size_t j = 0; j < combine_op->num_operands(); ++j) {
-        if (combine_op->operand_source(j).GetDefiningOp()->name() ==
-            "pd_op.data") {
-          auto data_op = combine_op->operand_source(j).GetDefiningOp();
-          auto data_place = data_op->attributes()
-                                .at("place")
-                                .dyn_cast<dialect::PlaceAttribute>()
-                                .data();
+        if (combine_op->operand_source(j)
+                .dyn_cast<pir::OpResult>()
+                .owner()
+                ->isa<DataOp>()) {
+          auto data_op =
+              combine_op->operand_source(j).dyn_cast<pir::OpResult>().owner();
+          auto data_place =
+              data_op->attribute<PlaceAttribute>("place").data();

auto data_op_backend =
paddle::experimental::ParseBackend(data_place);
@@ -981,7 +983,7 @@ std::vector<pir::Value> BuildOpInputList(
} else if (new_in_type.isa<pir::VectorType>()) {
// [ todo need update here, support combine data transfomer]
// deal with pre combine op
-      auto pre_define_op = cur_in.GetDefiningOp();
+      auto pre_define_op = cur_in.dyn_cast<pir::OpResult>().owner();

if (pre_define_op->isa<::pir::CombineOp>()) {
std::vector<pir::Value> inner_inputs;
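Besides the `owner()` migration, the `pd_op_to_kernel_pass.cc` hunks above also replace string comparisons on op names with type-safe `isa<>` checks. A minimal sketch of that change (`IsCombine` is a hypothetical helper, not part of this PR):

```cpp
// Hypothetical helper contrasting the two op-kind checks used above.
bool IsCombine(pir::Operation *op) {
  // Before: string comparison, not checked at compile time.
  //   return op->name() == "builtin.combine";
  // After: checked against the op class itself.
  return op->isa<pir::CombineOp>();
}
```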
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/transform_general_functions.cc
@@ -25,7 +25,7 @@ namespace pir {
std::pair<std::string, pir::Parameter*> GetParameterFromValue(
pir::Value value) {
pir::GetParameterOp op =
-      value.GetDefiningOp()->dyn_cast<pir::GetParameterOp>();
+      value.dyn_cast<OpResult>().owner()->dyn_cast<pir::GetParameterOp>();
PADDLE_ENFORCE_NOT_NULL(
op,
phi::errors::InvalidArgument(
@@ -66,7 +66,7 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
-  return op->operand_source(index).GetDefiningOp();
+  return op->operand_source(index).dyn_cast<OpResult>().owner();
}

Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
@@ -41,7 +41,7 @@ return vjp_res;
{% macro get_mutable_attribute(attrs, api_name) %}
{% for i in attrs %}
{%- if i is mutable_attribute -%}
-    auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->value().dyn_cast<pir::OpResult>().GetDefiningOp();
+    auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->value().dyn_cast<pir::OpResult>().owner();
{% if i.typename is scalar %}
if({{i.name}}_define_op->name() != "pd_op.full") {
PADDLE_THROW(platform::errors::Unimplemented(
26 changes: 13 additions & 13 deletions paddle/fluid/pybind/ir.cc
@@ -323,9 +323,12 @@ void BindValue(py::module *m) {

)DOC");
value
.def("get_defining_op",
&Value::GetDefiningOp,
return_value_policy::reference)
.def(
"get_defining_op",
[](const Value &self) {
return self.dyn_cast<pir::OpResult>().owner();
},
return_value_policy::reference)
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("has_one_use", &Value::HasOneUse)
.def("use_empty", &Value::use_empty)
@@ -444,22 +447,19 @@ void BindOpResult(py::module *m) {
})
.def("__hash__",
[](OpResult &self) { return std::hash<pir::Value>{}(self); })
.def("get_defining_op",
&OpResult::GetDefiningOp,
return_value_policy::reference)
.def("get_defining_op", &OpResult::owner, return_value_policy::reference)
.def_property_readonly(
"block",
[](OpResult &self) { return self.GetDefiningOp()->GetParent(); },
[](OpResult &self) { return self.owner()->GetParent(); },
return_value_policy::reference)
.def_property_readonly(
"name",
[](OpResult &self) {
-            if (self.GetDefiningOp()->isa<::pir::GetParameterOp>()) {
-              auto param_name = self.GetDefiningOp()
-                                    ->attributes()
-                                    .at("parameter_name")
-                                    .dyn_cast<pir::StrAttribute>()
-                                    .AsString();
+            if (self.owner()->isa<::pir::GetParameterOp>()) {
+              auto param_name =
+                  self.owner()
+                      ->attribute<pir::StrAttribute>("parameter_name")
+                      .AsString();
return param_name;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
6 changes: 3 additions & 3 deletions paddle/pir/core/attribute.h
@@ -51,6 +51,8 @@ class IR_API Attribute {

bool operator!() const { return storage_ == nullptr; }

+  operator const void *() const { return storage_; }
+
///
/// \brief Some Attribute attribute acquisition interfaces.
///
@@ -85,8 +87,6 @@
return pir::dyn_cast<U>(*this);
}

-  friend struct std::hash<Attribute>;
-
protected:
const Storage *storage_{nullptr};
};
@@ -98,7 +98,7 @@ namespace std {
template <>
struct hash<pir::Attribute> {
std::size_t operator()(const pir::Attribute &obj) const {
-    return std::hash<const pir::Attribute::Storage *>()(obj.storage_);
+    return std::hash<const void *>()(obj);
}
};
} // namespace std
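The `attribute.h` change above is the first of four applications of the same idiom (the `op_info.h`, `type.h`, and `type_id.h` diffs below repeat it): the `friend struct std::hash<...>` declaration goes away, and the `std::hash` specialization hashes through a public `operator const void *()` conversion instead of reaching into the private storage pointer. A self-contained sketch of the idiom, with an illustrative `Handle` class standing in for the pir types:

```cpp
#include <cstddef>
#include <functional>

// Illustrative stand-in for pir::Attribute / pir::Type / pir::TypeId:
// a thin handle over an opaque storage pointer.
class Handle {
 public:
  explicit Handle(const void *storage) : storage_(storage) {}
  // The public conversion exposes the identity pointer, so std::hash
  // no longer needs friend access to storage_.
  operator const void *() const { return storage_; }

 private:
  const void *storage_{nullptr};
};

namespace std {
template <>
struct hash<Handle> {
  std::size_t operator()(const Handle &h) const {
    return std::hash<const void *>()(h);  // uses the implicit conversion
  }
};
}  // namespace std
```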
2 changes: 1 addition & 1 deletion paddle/pir/core/builtin_op.cc
@@ -310,7 +310,7 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
for (uint32_t i = 0; i < defining_op->num_operands(); ++i) {
auto value = defining_op->operand_source(i);
if (!value) continue;
-    auto *oprand_defining_op = value.GetDefiningOp();
+    auto *oprand_defining_op = value.dyn_cast<OpResult>().owner();
if (oprand_defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = oprand_defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
4 changes: 2 additions & 2 deletions paddle/pir/core/op_info.h
@@ -71,13 +71,13 @@ class IR_API OpInfo {
template <typename InterfaceT>
typename InterfaceT::Concept *GetInterfaceImpl() const;

+  operator const void *() const { return impl_; }
void *AsOpaquePointer() const { return impl_; }
static OpInfo RecoverFromOpaquePointer(void *pointer) {
return OpInfo(static_cast<OpInfoImpl *>(pointer));
}

friend class OpInfoImpl;
-  friend struct std::hash<OpInfo>;

private:
explicit OpInfo(OpInfoImpl *impl) : impl_(impl) {}
@@ -105,7 +105,7 @@ namespace std {
template <>
struct hash<pir::OpInfo> {
std::size_t operator()(const pir::OpInfo &obj) const {
-    return std::hash<const pir::OpInfoImpl *>()(obj.impl_);
+    return std::hash<const void *>()(obj);
}
};
} // namespace std
15 changes: 3 additions & 12 deletions paddle/pir/core/type.h
@@ -48,7 +48,7 @@ class IR_API Type {
Type() = default;

Type(const Storage *storage) // NOLINT
-      : storage_(const_cast<Storage *>(storage)) {}
+      : storage_(storage) {}

Type(const Type &other) = default;

@@ -73,11 +73,7 @@
///
/// \brief Support PointerLikeTypeTraits.
///
-  ///
-  const void *AsOpaquePointer() const {
-    return static_cast<const void *>(storage_);
-  }
-
+  operator const void *() const { return storage_; }
static Type RecoverFromOpaquePointer(const void *pointer) {
return Type(reinterpret_cast<Storage *>(const_cast<void *>(pointer)));
}
@@ -120,11 +116,6 @@

static Type Parse(std::istream &is, IrContext *ctx);

-  ///
-  /// \brief Enable hashing Type.
-  ///
-  friend struct std::hash<Type>;
-
template <typename U>
U cast() const {
return pir::cast<U>(*this);
@@ -189,7 +180,7 @@ namespace std {
template <>
struct hash<pir::Type> {
std::size_t operator()(const pir::Type &obj) const {
-    return std::hash<const pir::Type::Storage *>()(obj.storage_);
+    return std::hash<const void *>()(obj);
}
};
} // namespace std
8 changes: 2 additions & 6 deletions paddle/pir/core/type_id.h
@@ -53,6 +53,7 @@ class TypeId {
///
/// \brief Support PointerLikeTypeTraits.
///
+  operator const void *() const { return storage_; }
void *AsOpaquePointer() const { return storage_; }
static TypeId RecoverFromOpaquePointer(void *pointer) {
return TypeId(static_cast<Storage *>(pointer));
@@ -71,11 +72,6 @@
return storage_ < other.storage_;
}

-  ///
-  /// \brief Enable hashing TypeId instances.
-  ///
-  friend struct std::hash<TypeId>;
-
private:
///
/// \brief Construct a TypeId and initialize storage.
@@ -150,7 +146,7 @@ namespace std {
template <>
struct hash<pir::TypeId> {
std::size_t operator()(const pir::TypeId &obj) const {
-    return std::hash<const pir::TypeId::Storage *>()(obj.storage_);
+    return std::hash<const void *>()(obj);
}
};
} // namespace std
5 changes: 0 additions & 5 deletions paddle/pir/core/value.cc
@@ -52,11 +52,6 @@ void Value::set_type(pir::Type type) {
impl_->set_type(type);
}

-Operation *Value::GetDefiningOp() const {
-  if (auto result = dyn_cast<OpResult>()) return result.owner();
-  return nullptr;
-}
-
std::string Value::PrintUdChain() {
CHECK_VALUE_NULL_IMPL(PrintUdChain);
return impl()->PrintUdChain();
2 changes: 0 additions & 2 deletions paddle/pir/core/value.h
@@ -59,8 +59,6 @@ class IR_API Value {

void set_type(Type type);

-  Operation *GetDefiningOp() const;
-
std::string PrintUdChain();

///
2 changes: 1 addition & 1 deletion paddle/pir/dialect/shape/utils/shape_utils.cc
@@ -84,7 +84,7 @@ bool SymbolicDimMgr::LoadShapeConstraintGraph() {
auto build_sym_product = [&](std::vector<Value> range,
SymbolicDimProduct& product) {
for (Value v : range) {
-      auto definingOp = v.GetDefiningOp();
+      auto definingOp = v.dyn_cast<OpResult>().owner();
if (auto constOp = definingOp->dyn_cast<ConstantOp>()) {
product.factor *= constOp.value().dyn_cast<Int32Attribute>().data();
continue;
3 changes: 2 additions & 1 deletion paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc
@@ -173,7 +173,8 @@ class GreedyPatternRewriteDriver : public pir::PatternRewriter {
// that single use values often have more canonicalization opportunities.
if (!operand || (!operand.use_empty() && !operand.HasOneUse())) return;

-    if (auto* def_op = operand.GetDefiningOp()) AddToWorklist(def_op);
+    if (auto* def_op = operand.dyn_cast<pir::OpResult>().owner())
+      AddToWorklist(def_op);
}

void AddOperandsToWorklist(const std::vector<pir::Value> operands) {