Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Improve Dtype Transfer (add infermeta by value) #60677

Merged
merged 12 commits into from
Jan 15, 2024

Conversation

chen2016013
Copy link
Contributor

@chen2016013 chen2016013 commented Jan 9, 2024

PR types

Others

PR changes

Others

Description

背景: 在运算中,有一种场景如下:对于某个二元运算op,输入x.dype()=float64,y.dtype()=complex128,infermeta结果float64(按照第一个输入确定)。在pass中会对x做data transfer以便和y进行计算,此时x.dype()=y.dtype()=complex128,但op_result的结果及后续使用该输出的相关op的输入输出类型没有做transfer,仍然是float64,导致错误。

修改方案:在构建kernel op output前,对于所有op再进行一轮Infermeta。具体地,将Infermeta抽象出相应接口在pd_op_to_kernel_pass中调用

特例:用户上层组网时可能指定输出的 dtype,这部分dtype不应该被改变

Pcard-67164

Copy link

paddle-bot bot commented Jan 9, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@zhangbo9674
Copy link
Contributor

请完善一下 pr 描述

@chen2016013 chen2016013 changed the title [PIR] Infermeta by value [PIR] Improve Dtype Transfer (add infermeta by value) Jan 15, 2024
Copy link
Contributor

@zhangbo9674 zhangbo9674 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM over rall

return get_attributes_str


def gen_infermeta_func_str(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op_build_gen.py 文件中原本算子的 Build 函数中,可以替换为该infer_meta 接口?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单独提PR修改

@@ -177,6 +177,71 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}

std::vector<pir::Type> AddNOp::InferMeta(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

将 Build 函数中相关逻辑进行替换

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单独提PR修改

"DenseTensorType and SelectedRowsType."));
}
}

std::string GetValueDataType(const pir::Value& value) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以调用新增的GetValueDataType(pir::Type type)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单独提PR修改

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chen2016013 chen2016013 merged commit de37c94 into PaddlePaddle:develop Jan 15, 2024
29 checks passed
@chen2016013 chen2016013 deleted the infermeta2 branch January 17, 2024 06:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants