-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PIR] Improve Dtype Transfer (add infermeta by value) #60677
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
请完善一下 pr 描述 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM over rall
return get_attributes_str | ||
|
||
|
||
def gen_infermeta_func_str( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_build_gen.py 文件中原本算子的 Build 函数中,可以替换为该infer_meta 接口?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单独提PR修改
@@ -177,6 +177,71 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) { | |||
fn(infer_meta); | |||
} | |||
|
|||
std::vector<pir::Type> AddNOp::InferMeta( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
将 Build 函数中相关逻辑进行替换
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单独提PR修改
"DenseTensorType and SelectedRowsType.")); | ||
} | ||
} | ||
|
||
std::string GetValueDataType(const pir::Value& value) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以调用新增的GetValueDataType(pir::Type type)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单独提PR修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Description
背景: 在运算中,有一种场景如下:对于某个二元运算op,输入x.dype()=float64,y.dtype()=complex128,infermeta结果float64(按照第一个输入确定)。在pass中会对x做data transfer以便和y进行计算,此时x.dype()=y.dtype()=complex128,但op_result的结果及后续使用该输出的相关op的输入输出类型没有做transfer,仍然是float64,导致错误。
修改方案:在构建kernel op output前,对于所有op再进行一轮Infermeta。具体地,将Infermeta抽象出相应接口在pd_op_to_kernel_pass中调用
特例:用户上层组网时可能指定输出的 dtype,这部分dtype不应该被改变
Pcard-67164