Skip to content

Commit

Permalink
[PIR] standardize the use of value.[4-4] (PaddlePaddle#57373)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Sep 18, 2023
1 parent 169d63c commit a6459cd
Show file tree
Hide file tree
Showing 25 changed files with 180 additions and 183 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1520,7 +1520,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
phi::DataType::INT64,
phi::CPUPlace());
auto y_true_shape_op = builder.Build<pir::CombineOp>(
std::vector<pir::OpResult>{shape_op.out(), append_shape_op.out()});
std::vector<pir::Value>{shape_op.out(), append_shape_op.out()});
auto concat_op =
builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0);
auto y_new_shape = concat_op.out();
Expand Down
21 changes: 11 additions & 10 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
COMPUTE_OP_TEMPLATE = """
paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""

OP_RESULT = 'pir::OpResult'
OP_INPUT = 'pir::Value'
VECTOR_TYPE = 'pir::VectorType'
INTARRAY_ATTRIBUTE = "paddle::dialect::IntArrayAttribute"

Expand All @@ -96,6 +96,11 @@ def get_op_class_name(op_name):
class CodeGen:
def __init__(self) -> None:
self._type_map = {
'paddle::dialect::DenseTensorType': 'pir::Value',
'paddle::dialect::SelectedRowsType': 'pir::Value',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::Value>',
}
self._ret_type_map = {
'paddle::dialect::DenseTensorType': 'pir::OpResult',
'paddle::dialect::SelectedRowsType': 'pir::OpResult',
'pir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<pir::OpResult>',
Expand Down Expand Up @@ -160,18 +165,14 @@ def _gen_api_attrs(
== INTARRAY_ATTRIBUTE
and is_vector_mutable_attr
):
mutable_attr.append(f'std::vector<{OP_RESULT}> {name}')
mutable_attr.append(f'std::vector<{OP_INPUT}> {name}')
else:
mutable_attr.append(f'{OP_RESULT} {name}')
mutable_attr.append(f'{OP_INPUT} {name}')
continue
if with_default and default_value is not None:
if type in ['float', 'double']:
default_value = default_value.strip('"')
no_mutable_attr.append(
'{type} {name} = {default_value}'.format(
type=type, name=name, default_value=default_value
)
)
no_mutable_attr.append(f'{type} {name} = {default_value}')
else:
no_mutable_attr.append(f'{type} {name}')
return ', '.join(mutable_attr + no_mutable_attr)
Expand Down Expand Up @@ -199,7 +200,7 @@ def _gen_ret_type(self, op_info):
return 'std::tuple<{}>'.format(
', '.join(
[
self._type_map[type]
self._ret_type_map[type]
for type, intermediate in zip(
type_list, intermediate_list
)
Expand All @@ -209,7 +210,7 @@ def _gen_ret_type(self, op_info):
)
elif output_num == 1:
index = intermediate_list.index('false')
return self._type_map[type_list[index]]
return self._ret_type_map[type_list[index]]
elif output_num == 0:
return 'void'

Expand Down
43 changes: 18 additions & 25 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ def GenBuildInputArgsStr(
attr_args_is_map=False,
):
'''
Example: pir::Builder &builder, pir::OperationArgument &argument, pir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
Example: pir::Builder &builder, pir::OperationArgument &argument, pir::Value x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
'''
# add inputs
build_args_str = "pir::Builder &builder, pir::OperationArgument &argument"
if len(op_input_name_list) > 0:
for input_name in op_input_name_list:
build_args_str += ", pir::OpResult " + input_name + "_"
build_args_str += ", pir::Value " + input_name + "_"

if attr_args_is_map:
build_args_str += ", pir::AttributeMap attributes"
Expand Down Expand Up @@ -92,7 +92,7 @@ def GenBuildInputArgsStr(
# add mutable attributes as inputs
if len(op_mutable_attribute_name_list) > 0:
for mutable_attr in op_mutable_attribute_name_list:
build_args_str += ", pir::OpResult " + mutable_attr + "_"
build_args_str += ", pir::Value " + mutable_attr + "_"

# add non-mutable attributes
for attr_idx in range(len(op_non_mutable_attribute_name_list)):
Expand Down Expand Up @@ -183,8 +183,8 @@ def GenBuildInserFullForMutableAttribute(


def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list):
BUILD_INPUT_TEMPLATE = """ std::vector<pir::OpResult> argument_inputs = {{{inputs_args}}};
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());
BUILD_INPUT_TEMPLATE = """ std::vector<pir::Value> argument_inputs = {{{inputs_args}}};
argument.AddInputs(argument_inputs);
"""
build_input_str = ' VLOG(4) << "Builder construction inputs";\n'
input_name_list = op_input_name_list + op_mutable_attribute_name_list
Expand Down Expand Up @@ -344,16 +344,11 @@ def GenBuildOutputs(
meta_{name}.push_back(&vec_meta_{name}[i]);
}}
"""

CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""

CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name};
if ({name}_.owner()->info().id() == pir::TypeId::get<paddle::dialect::FullIntArrayOp>()) {{
{name} = std::move(phi::IntArray({name}_.owner()
if ({name}_.dyn_cast<pir::OpResult>().owner()->isa<paddle::dialect::FullIntArrayOp>()) {{
{name} = std::move(phi::IntArray({name}_.dyn_cast<pir::OpResult>().owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attributes()
.at("value")
.attribute("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData()));
Expand All @@ -370,14 +365,13 @@ def GenBuildOutputs(
}}\n"""

CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector<int64_t> {name};
if ({name}_.owner()->info().id() == pir::TypeId::get<paddle::dialect::FullIntArrayOp>()) {{
{name} = {name}_.owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
if ({name}_.dyn_cast<pir::OpResult>().owner()->isa<paddle::dialect::FullIntArrayOp>()) {{
{name} = {name}_.dyn_cast<pir::OpResult>().owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attribute("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
}} else if ({name}_.type().isa<pir::VectorType>()) {{
size_t {name}_size = {name}_.type().dyn_cast<pir::VectorType>().size();
{name} = std::vector<int64_t>({name}_size, -1);
Expand All @@ -389,11 +383,10 @@ def GenBuildOutputs(
}}\n"""

CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name};
if ({name}_.owner()->info().id() == pir::TypeId::get<paddle::dialect::FullOp>()) {{
{name} = std::move(phi::Scalar({name}_.owner()
if ({name}_.dyn_cast<pir::OpResult>().owner()->isa<paddle::dialect::FullOp>()) {{
{name} = std::move(phi::Scalar({name}_.dyn_cast<pir::OpResult>().owner()
->dyn_cast<paddle::dialect::FullOp>()
.attributes()
.at("value")
.attribute("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int>()));
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
}"""

OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<pir::OpResult>> {op_class_name}::Vjp(pir::Operation* op, const std::vector<std::vector<pir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
std::vector<std::vector<pir::OpResult>> {op_class_name}::Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); (void)op_obj;
VLOG(6) << "Prepare inputs of {op_grad_name}";
Expand Down Expand Up @@ -254,5 +254,5 @@ def gen_exclusive_interface_str(op_info):
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)
if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list:
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
return exclusive_interface_str
16 changes: 8 additions & 8 deletions paddle/fluid/pir/dialect/op_generator/python_c_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from api_gen import (
INTARRAY_ATTRIBUTE,
NAMESPACE_TEMPLATE,
OP_RESULT,
OP_INPUT,
VECTOR_TYPE,
CodeGen,
)
Expand Down Expand Up @@ -64,7 +64,7 @@
VLOG(6) << "Add {api_name} op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
// Get Value from args
{inputs}
// Parse Attributes
Expand All @@ -87,7 +87,7 @@
VLOG(6) << "Add {api_name} op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
// Get Value from args
{inputs}
// Parse Attributes
Expand Down Expand Up @@ -118,7 +118,7 @@
VLOG(6) << "Add {api_name} op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
// Get Value from args
{inputs}
// Parse Attributes
Expand Down Expand Up @@ -245,7 +245,7 @@ def _gen_inputs(self, op_info, op_name):
ret = ''
for i, (name, type) in enumerate(zip(name_list, type_list)):
cast_func = (
'CastPyArg2VectorOfOpResult'
'CastPyArg2VectorOfValue'
if VECTOR_TYPE in type
else 'CastPyArg2OpResult'
)
Expand Down Expand Up @@ -286,7 +286,7 @@ def _gen_init_mutable_attrs(self, op_info):
mutable_attr_name_list = op_info.mutable_attribute_name_list
ret = ''
for name in mutable_attr_name_list:
ret += INIT_ATTRS_TEMPLATE.format(type=OP_RESULT, name=name)
ret += INIT_ATTRS_TEMPLATE.format(type=OP_INPUT, name=name)

return ret

Expand All @@ -311,10 +311,10 @@ def _gen_cast_attrs(self, op_info, op_name):
== INTARRAY_ATTRIBUTE
):
mutable_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format(
type='std::vector<pir::OpResult>',
type='std::vector<pir::Value>',
name_=name + '_tmp',
name=name,
cast_func='CastPyArg2VectorOfOpResult',
cast_func='CastPyArg2VectorOfValue',
api_name=op_name,
index=input_size + i,
)
Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/pir/dialect/operator/interface/interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
namespace paddle {
namespace dialect {
// Adapter overload: accepts output gradients expressed as OpResults and
// forwards them to the Value-based vjp implementation stored in the
// interface's Concept (impl_->vjp_).
//
// @param op             The forward operation whose vjp is being computed.
// @param out_grads      Per-output gradient groups, as pir::OpResult.
// @param stop_gradients Per-output stop-gradient flags, forwarded unchanged.
// @return The input gradients produced by the underlying vjp function.
std::vector<std::vector<pir::OpResult>> VjpInterface::Vjp(
    pir::Operation* op,
    const std::vector<std::vector<pir::OpResult>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients) {
  std::vector<std::vector<pir::Value>> out_grads_value;
  // Reserve up front so the outer vector never reallocates mid-loop.
  out_grads_value.reserve(out_grads.size());
  for (const auto& grad : out_grads) {
    // Iterator-range construction converts each pir::OpResult to pir::Value
    // element-wise (OpResult is implicitly convertible to Value), replacing
    // the manual per-element copy loop.
    out_grads_value.emplace_back(grad.begin(), grad.end());
  }
  return impl_->vjp_(op, out_grads_value, stop_gradients);
}
}  // namespace dialect
}  // namespace paddle

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferMetaInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface)
Expand Down
13 changes: 9 additions & 4 deletions paddle/fluid/pir/dialect/operator/interface/vjp.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {
struct Concept {
explicit Concept(std::vector<std::vector<pir::OpResult>> (*vjp)(
pir::Operation* op,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients))
: vjp_(vjp) {}
std::vector<std::vector<pir::OpResult>> (*vjp_)(
pir::Operation* op,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);
};

template <class ConcreteOp>
struct Model : public Concept {
static std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
return ConcreteOp::Vjp(op, out_grads, stop_gradients);
}
Expand All @@ -48,11 +48,16 @@ class VjpInterface : public pir::OpInterfaceBase<VjpInterface> {

std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
return impl_->vjp_(op, out_grads, stop_gradients);
}

std::vector<std::vector<pir::OpResult>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);

private:
Concept* impl_;
};
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
namespace paddle {
namespace dialect {

pir::OpResult builtin_combine(std::vector<pir::OpResult> x) {
pir::OpResult builtin_combine(const std::vector<pir::Value>& x) {
auto combine_op =
APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(x);
return combine_op.out();
}

pir::OpResult zeros_like(pir::OpResult x,
pir::OpResult zeros_like(pir::Value x,
phi::DataType dtype,
const Place& place) {
return paddle::dialect::full_like(x, 0, dtype, place);
Expand All @@ -52,14 +52,14 @@ pir::OpResult get_parameter(const std::string& name,
return get_parameter_op.result(0);
}

void set_parameter(pir::OpResult parameter, const std::string& name) {
void set_parameter(pir::Value parameter, const std::string& name) {
APIBuilder::Instance().GetBuilder()->Build<pir::SetParameterOp>(parameter,
name);
}

pir::OpResult embedding_grad(pir::OpResult x,
pir::OpResult weight,
pir::OpResult out_grad,
pir::OpResult embedding_grad(pir::Value x,
pir::Value weight,
pir::Value out_grad,
int64_t padding_idx,
bool sparse) {
if (weight.type().isa<paddle::dialect::DenseTensorType>()) {
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@
namespace paddle {
namespace dialect {

pir::OpResult builtin_combine(std::vector<pir::OpResult> x);
pir::OpResult builtin_combine(const std::vector<pir::Value>& x);

pir::OpResult zeros_like(pir::OpResult x,
pir::OpResult zeros_like(pir::Value x,
phi::DataType dtype = phi::DataType::UNDEFINED,
const Place& place = {});

pir::OpResult get_parameter(const std::string& name,
phi::DataType dtype,
const std::vector<int64_t>& shape);

void set_parameter(pir::OpResult parameter, const std::string& name);
void set_parameter(pir::Value parameter, const std::string& name);

pir::OpResult embedding_grad(pir::OpResult x,
pir::OpResult weight,
pir::OpResult out_grad,
pir::OpResult embedding_grad(pir::Value x,
pir::Value weight,
pir::Value out_grad,
int64_t padding_idx = -1,
bool sparse = false);

Expand Down
Loading

0 comments on commit a6459cd

Please sign in to comment.