Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Reconstruct the Verify system #58052

Merged
merged 9 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ std::vector<pir::Operation *> GroupOp::ops() {
inner_block->end());
}

void GroupOp::Verify() {}
void GroupOp::VerifySig() {}

void GroupOp::Print(pir::IrPrinter &printer) {
auto &os = printer.os;
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class GroupOp : public pir::Op<GroupOp> {
pir::Block *block();
std::vector<pir::Operation *> ops();

void Verify();
void VerifySig();
void Print(pir::IrPrinter &printer); // NOLINT
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace dialect {

const char* JitKernelOp::attributes_name[attributes_num] = {kAttrName};

void JitKernelOp::Verify() {
void JitKernelOp::VerifySig() {
VLOG(4) << "Verifying inputs, outputs and attributes for: JitKernelOp.";

auto& attributes = this->attributes();
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class JitKernelOp : public ::pir::Op<JitKernelOp> {

hlir::framework::Instruction* instruction();

void Verify();
void VerifySig();
};

} // namespace dialect
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
true);
}
VLOG(4) << "[general op][conditional_block] IfOp false block translate end.";

operation->Verify();
VLOG(4) << "[general op][conditional_block] IfOp translate end.";
return operation;
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const char* PhiKernelOp::attributes_name[attributes_num] = { // NOLINT
"kernel_name",
"kernel_key"};

void PhiKernelOp::Verify() {
void PhiKernelOp::VerifySig() {
VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp.";

auto& attributes = this->attributes();
Expand Down Expand Up @@ -64,7 +64,7 @@ const char* LegacyKernelOp::attributes_name[attributes_num] = { // NOLINT
"kernel_name",
"kernel_key"};

void LegacyKernelOp::Verify() {
void LegacyKernelOp::VerifySig() {
VLOG(4) << "Verifying inputs, outputs and attributes for: LegacyKernelOp.";

auto& attributes = this->attributes();
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class PhiKernelOp : public pir::Op<PhiKernelOp> {
std::string op_name();
std::string kernel_name();
phi::KernelKey kernel_key();
void Verify();
void VerifySig();
};

class LegacyKernelOp : public pir::Op<LegacyKernelOp> {
Expand All @@ -41,7 +41,7 @@ class LegacyKernelOp : public pir::Op<LegacyKernelOp> {
std::string op_name();
std::string kernel_name();
phi::KernelKey kernel_key();
void Verify();
void VerifySig();
};

} // namespace dialect
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
{build_mutable_attr_is_input}
{build_attr_num_over_1}
{build_mutable_attr_is_input_attr_num_over_1}
void Verify();
void VerifySig();
{get_inputs_and_outputs}
{exclusive_interface}
}};
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_verify_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::Verify() {{
void {op_name}::VerifySig() {{
VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}.";
VLOG(4) << "Verifying inputs:";
{{
Expand All @@ -36,7 +36,7 @@
"""

GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::Verify() {{}}
void {op_name}::VerifySig() {{}}
"""

INPUT_TYPE_CHECK_TEMPLATE = """
Expand Down
70 changes: 69 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp

#include "paddle/phi/core/enforce.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/operation_utils.h"
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"
Expand Down Expand Up @@ -109,7 +110,74 @@ void IfOp::Print(pir::IrPrinter &printer) {
}
os << "\n }";
}
void IfOp::Verify() {}
void IfOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: IfOp.";
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", input_size));

if ((*this)->operand_source(0).type().isa<pir::DenseTensorType>()) {
PADDLE_ENFORCE(
(*this)
->operand_source(0)
.type()
.dyn_cast<pir::DenseTensorType>()
.dtype()
.isa<pir::BoolType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th input, it should be a "
"bool DenseTensorType."));
}

PADDLE_ENFORCE_EQ((*this)->num_regions(),
2u,
phi::errors::PreconditionNotMet(
"The size %d of regions must be equal to 2.",
(*this)->num_regions()));
}

void IfOp::VerifyRegion() {
VLOG(4) << "Start Verifying sub regions for: IfOp.";
PADDLE_ENFORCE_EQ(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true_region和false_region的size可以不相等。
true_region的size一定等于1。 如果输入为空,false_region的size可以为0, 否则,一定为1;

(*this)->region(0).size(),
1u,
phi::errors::PreconditionNotMet("The size %d of true_region must be 1.",
(*this)->region(0).size()));

if ((*this)->num_results() != 0) {
PADDLE_ENFORCE_EQ(
(*this)->region(0).size(),
(*this)->region(1).size(),
phi::errors::PreconditionNotMet("The size %d of true_region must be "
"equal to the size %d of false_region.",
(*this)->region(0).size(),
(*this)->region(1).size()));

auto *true_last_op = (*this)->region(0).front()->back();
auto *false_last_op = (*this)->region(1).front()->back();
PADDLE_ENFORCE_EQ(true_last_op->isa<pir::YieldOp>(),
true,
phi::errors::PreconditionNotMet(
"The last of true block must be YieldOp"));
PADDLE_ENFORCE_EQ(true_last_op->num_operands(),
(*this)->num_results(),
phi::errors::PreconditionNotMet(
"The size of last of true block op's input must be "
"equal to IfOp's outputs num."));
PADDLE_ENFORCE_EQ(false_last_op->isa<pir::YieldOp>(),
true,
phi::errors::PreconditionNotMet(
"The last of false block must be YieldOp"));
PADDLE_ENFORCE_EQ(false_last_op->num_operands(),
(*this)->num_results(),
phi::errors::PreconditionNotMet(
"The size of last of false block op's input must be "
"equal to IfOp's outputs num."));
}
}

void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class IfOp : public pir::Op<IfOp> {
pir::Block *true_block();
pir::Block *false_block();
void Print(pir::IrPrinter &printer); // NOLINT
void Verify();
void VerifySig();
void VerifyRegion();
};

class WhileOp : public pir::Op<WhileOp> {
Expand All @@ -57,7 +58,8 @@ class WhileOp : public pir::Op<WhileOp> {
pir::Block *cond_block();
pir::Block *body_block();
void Print(pir::IrPrinter &printer); // NOLINT
void Verify() {}
void VerifySig() {}
void VerifyRegion() {}
};

} // namespace dialect
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ OpInfoTuple AddNOp::GetOpInfo() {
return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
}

void AddNOp::Verify() {
void AddNOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddNOp.";
VLOG(4) << "Verifying inputs:";
{
Expand Down Expand Up @@ -222,7 +222,7 @@ void AddN_Op::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void AddN_Op::Verify() {
void AddN_Op::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddN_Op.";
VLOG(4) << "Verifying inputs:";
{
Expand Down Expand Up @@ -345,7 +345,7 @@ void AddNWithKernelOp::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void AddNWithKernelOp::Verify() {
void AddNWithKernelOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
"AddNWithKernelOp.";
VLOG(4) << "Verifying inputs:";
Expand Down Expand Up @@ -561,7 +561,7 @@ void FusedGemmEpilogueOp::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void FusedGemmEpilogueOp::Verify() {
void FusedGemmEpilogueOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
"FusedGemmEpilogueOp.";
VLOG(4) << "Verifying inputs:";
Expand Down Expand Up @@ -833,7 +833,7 @@ void FusedGemmEpilogueGradOp::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void FusedGemmEpilogueGradOp::Verify() {}
void FusedGemmEpilogueGradOp::VerifySig() {}

void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta);
Expand Down Expand Up @@ -983,7 +983,7 @@ void SplitGradOp::Build(pir::Builder &builder,
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}

void SplitGradOp::Verify() {
void SplitGradOp::VerifySig() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp.";
VLOG(4) << "Verifying inputs:";
{
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AddNOp : public pir::Op<AddNOp,
pir::OperationArgument &argument, // NOLINT
pir::Value inputs);

void Verify();
void VerifySig();
pir::Value inputs() { return operand_source(0); }
pir::OpResult out() { return result(0); }
static void InferMeta(phi::InferMetaContext *infer_meta);
Expand All @@ -69,7 +69,7 @@ class AddN_Op : public pir::Op<AddN_Op,
pir::OperationArgument &argument, // NOLINT
pir::Value inputs_);

void Verify();
void VerifySig();
pir::Value inputs() { return operand_source(0); }
pir::OpResult out() { return result(0); }

Expand All @@ -89,7 +89,7 @@ class AddNWithKernelOp : public pir::Op<AddNWithKernelOp,
pir::OperationArgument &argument, // NOLINT
pir::Value inputs_);

void Verify();
void VerifySig();
pir::Value inputs() { return operand_source(0); }
pir::OpResult out() { return result(0); }

Expand All @@ -113,7 +113,7 @@ class FusedGemmEpilogueOp
pir::Value y_,
pir::Value bias_,
pir::AttributeMap attributes);
void Verify();
void VerifySig();
pir::Value x() { return operand_source(0); }
pir::Value y() { return operand_source(1); }
pir::Value bias() { return operand_source(2); }
Expand Down Expand Up @@ -141,7 +141,7 @@ class FusedGemmEpilogueGradOp
pir::Value reserve_space_,
pir::Value out_grad_,
pir::AttributeMap attributes);
void Verify();
void VerifySig();
pir::Value x() { return operand_source(0); }
pir::Value y() { return operand_source(1); }
pir::Value reserve_space() { return operand_source(2); }
Expand Down Expand Up @@ -169,7 +169,7 @@ class SplitGradOp : public pir::Op<SplitGradOp, OpYamlInfoInterface> {
pir::Value out_grad_,
pir::Value axis_);

void Verify();
void VerifySig();
pir::Value out_grad() { return operand_source(0); }
pir::Value axis() { return operand_source(1); }
pir::OpResult x_grad() { return result(0); }
Expand Down
16 changes: 8 additions & 8 deletions paddle/pir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void ModuleOp::Destroy() {
}
}

void ModuleOp::Verify() const {
void ModuleOp::VerifySig() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
Expand Down Expand Up @@ -118,7 +118,7 @@ void GetParameterOp::PassStopGradients(OperationArgument &argument) {
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void GetParameterOp::Verify() const {
void GetParameterOp::VerifySig() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
Expand All @@ -144,7 +144,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
argument.AddAttribute(attributes_name[0],
pir::StrAttribute::get(builder.ir_context(), name));
}
void SetParameterOp::Verify() const {
void SetParameterOp::VerifySig() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
Expand All @@ -170,7 +170,7 @@ void ShadowOutputOp::Build(Builder &builder, // NOLINT
argument.AddAttribute(attributes_name[0],
pir::StrAttribute::get(builder.ir_context(), name));
}
void ShadowOutputOp::Verify() const {
void ShadowOutputOp::VerifySig() const {
VLOG(4) << "Verifying inputs, outputs and attributes for: ShadowOutputOp.";
// Verify inputs:
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
Expand Down Expand Up @@ -198,7 +198,7 @@ void CombineOp::Build(Builder &builder,
PassStopGradientsDefaultly(argument);
}

void CombineOp::Verify() const {
void CombineOp::VerifySig() const {
// outputs.size() == 1
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");

Expand Down Expand Up @@ -260,7 +260,7 @@ void SliceOp::PassStopGradients(OperationArgument &argument, int index) {
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void SliceOp::Verify() const {
void SliceOp::VerifySig() const {
// inputs.size() == 1
auto input_size = num_operands();
IR_ENFORCE(
Expand Down Expand Up @@ -364,7 +364,7 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
}

void SplitOp::Verify() const {
void SplitOp::VerifySig() const {
// inputs.size() == 1
IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1.");

Expand Down Expand Up @@ -393,7 +393,7 @@ void ConstantOp::Build(Builder &builder,
argument.output_types.push_back(output_type);
}

void ConstantOp::Verify() const {
void ConstantOp::VerifySig() const {
IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0.");
IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(attributes().count("value") > 0, "must has value attribute");
Expand Down
Loading