
[OpTest] support prim test in OpTest #50509

Merged (29 commits, Feb 21, 2023)

Conversation

@Charles-hit (Contributor) commented Feb 14, 2023

PR types

New features

PR changes

Others

Describe

Background:
The composite-operator + compiler co-design effort needs to guarantee operator accuracy. To reuse the framework's existing operator unit-test cases, this PR extends OpTest with a new composite (prim) testing capability.
Changes in this PR:
1. Add composite testing support to OpTest.
2. Fix an OpTest framework issue where a failing unit test left the dygraph/static-graph run mode in an incorrect state; the static-graph code is now protected with a guard.
3. Add composite unit tests for the softmax, expand, and reduce_sum operators.

@paddle-bot bot commented Feb 14, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

return kernel_sig

def is_only_check_prim(self):
return self.only_prim
Contributor:
why only test prim?

Contributor Author:
Some newly added composite unit tests also trigger the regular Op test, but certain ops are not allowed to be tested in fp32, which makes those unit tests fail. Composite testing has no such restriction, so this switch is used to skip the non-composite checks.
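A minimal sketch of that switch (class and method names simplified from the PR): the composite (prim) check always runs, while the regular kernel check is skipped when the test declares itself prim-only.

```python
# Simplified sketch of the only_prim switch; MiniOpTest is a hypothetical
# stand-in for the real OpTest base class.
class MiniOpTest:
    def __init__(self, only_prim=False):
        self.only_prim = only_prim
        self.checks_run = []

    def is_only_check_prim(self):
        return self.only_prim

    def check_output(self, check_prim=False):
        if check_prim:
            self.checks_run.append("prim")  # composite check always runs
        if self.is_only_check_prim():
            return  # skip the non-composite kernel check (e.g. fp32 forbidden)
        self.checks_run.append("kernel")

t = MiniOpTest(only_prim=True)
t.check_output(check_prim=True)
```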

)
if not kernel_sig:
return None
assert hasattr(self, "python_api"), (
"Detect there is KernelSignature for `%s` op, please set the `self.python_api` if you set check_dygraph = True"
% self.op_type
)
args = prepare_python_api_arguments(
args = OpTestUtils.prepare_python_api_arguments(
Contributor:
add todo: those code changes will be reverted after legacy dygraph is deleted

Contributor Author:
done

inplace_atol=None,
):
core._set_prim_all_enabled(False)

Contributor:
_set_prim_forward_enabled enough?

Contributor Author:
The forward cases here also run a backward inplace test. This is a fix for the case where a failed unit test leaves the forward/backward prim switches on, so the inplace test falls into the composite inplace logic and segfaults.
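The defensive reset can be sketched like this (the dict below is a hypothetical stand-in for the real global switches behind core._set_prim_all_enabled):

```python
# Suppose a previously failed unit test leaked the switches in the "on" state:
_PRIM_ENABLED = {"forward": True, "backward": True}

def set_prim_all_enabled(flag):
    _PRIM_ENABLED["forward"] = flag
    _PRIM_ENABLED["backward"] = flag

def check_inplace_output_with_place():
    # Reset first: with a leaked backward switch, the inplace test would be
    # routed into the composite inplace logic and crash.
    set_prim_all_enabled(False)
    return dict(_PRIM_ENABLED)

state = check_inplace_output_with_place()
```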

"rev_comp": {"rtol": 1e-2, "atol": 1e-2},
"cinn": {"rtol": 1e-1, "atol": 1e-1},
},
}
Contributor:
Should bfloat16 be supported?

Contributor Author:
numpy has no bfloat16 dtype, so uint16 is used here to represent bfloat16; the Python APIs all do it this way at the moment.
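The uint16-as-bfloat16 convention can be sketched with plain numpy. A truncating conversion is shown purely for illustration (helper names are hypothetical; Paddle's real helpers may round differently):

```python
import numpy as np

# numpy has no bfloat16, so the upper 16 bits of each float32 are carried in
# a uint16 array. Truncation drops the low 16 mantissa bits.
def float32_to_uint16_bf16(x):
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def uint16_bf16_to_float32(x):
    return (x.astype(np.uint32) << 16).view(np.float32)

# 1.0 and 3.140625 are exactly representable in bfloat16, so they round-trip.
a = np.array([1.0, 3.140625], dtype=np.float32)
bits = float32_to_uint16_bf16(a)
back = uint16_bf16_to_float32(bits)
```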

if check_prim:
prim_checker = PrimForwardChecker(self, place)
prim_checker.check()
# Operators not in the NO_FP64_CHECK_GRAD_OP_LIST can be prim-tested with fp32
Contributor:
Will users notice the dtype logic here?

Contributor Author:
No.

@@ -66,6 +55,8 @@ def setUp(self):
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
self.gradient = self.calc_gradient()
# error occurred in cinn
Contributor:
add todo

Contributor Author:
We will later scan the codebase to find which cases have not enabled cinn and handle them all in one pass.

@@ -1265,6 +1265,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
'x',
[
'bool',
'uint16',
Contributor:
Why does this data type need to be added?

Contributor Author:
Because uint16 is used on the Python side to represent bfloat16 (numpy has no bfloat16). The static-graph bfloat16 path has not been fully tested yet, so uint16 is added here to test the sum operator with bfloat16 in static graph mode.

@xiaoguoguo626807 (Contributor) left a comment:
LGTM

@cyber-pioneer (Contributor):
LGTM

@XiaoguangHu01 (Contributor) left a comment:
LGTM

@Charles-hit Charles-hit merged commit 457defe into PaddlePaddle:develop Feb 21, 2023
@JiabinYang (Contributor) left a comment:
some comments for now


import numpy as np

TOLERANCE = {
Contributor:
Is this config used only for op test?
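For context, TOLERANCE is a per-dtype, per-check tolerance table; a consumer might look it up as sketched below. The shape is inferred from the fragment quoted above; all values except the fp16 rev_comp/cinn entries visible in the diff are illustrative assumptions.

```python
# Illustrative shape: dtype -> check kind -> {"rtol", "atol"}.
TOLERANCE = {
    "float32": {
        "forward": {"rtol": 1e-6, "atol": 1e-6},
        "rev_comp": {"rtol": 1e-6, "atol": 1e-6},
    },
    "float16": {
        "forward": {"rtol": 1e-3, "atol": 1e-3},
        "rev_comp": {"rtol": 1e-2, "atol": 1e-2},
        "cinn": {"rtol": 1e-1, "atol": 1e-1},
    },
}

def get_tolerance(dtype, check_kind):
    cfg = TOLERANCE[dtype][check_kind]
    return cfg["rtol"], cfg["atol"]
```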

@@ -401,6 +401,7 @@ def is_custom_device_op_test():
and not is_npu_op_test()
and not is_mlu_op_test()
and not is_custom_device_op_test()
and not cls.check_prim
Contributor:
why this?


sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker
Contributor:
Make OpTestUtils an independent file instead of part of prim_op_test, since it's not only for prim op test.

program = Program()
block = program.global_block()
op = self._append_ops(block)
with paddle.fluid.framework._dygraph_guard(None):
Contributor:
Do not use fluid api?

Contributor:
Why dygraph_guard? These are all static operations below.

grad_program
).with_data_parallel(
loss_name="", build_strategy=build_strategy, places=place
with paddle.fluid.framework._dygraph_guard(None):
Contributor:
same... use program_guard

self.checker_name = "PrimForwardChecker"
self.place = place
self.op_test = op_test
self.save_eager_or_static_status()
Contributor:
why this?

def init_checker(self):
assert hasattr(
self.op_test, 'prim_op_type'
), "if you want to test comp op, please set prim_op_type in setUp function."
Contributor:
add more comments for prim_op_type, what is it?
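The attributes init_checker asserts on would be declared in the test's setUp. A simplified illustration (class name and attribute values are hypothetical, not taken from the PR):

```python
# Simplified illustration of the contract init_checker enforces: a composite
# test must declare prim_op_type and dtype in setUp.
class FakeCompTest:
    def setUp(self):
        self.op_type = "softmax"
        self.prim_op_type = "comp"   # op is checked via its composite rule
        self.dtype = "float32"

def init_checker(op_test):
    assert hasattr(op_test, "prim_op_type"), (
        "if you want to test comp op, please set prim_op_type in setUp function."
    )
    assert hasattr(op_test, "dtype"), "Please set dtype in setUp function."
    return op_test.prim_op_type, op_test.dtype

t = FakeCompTest()
t.setUp()
info = init_checker(t)
```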

assert hasattr(
self.op_test, 'dtype'
), "Please set dtype in setUp function."
self.op_type = self.op_test.op_type
Contributor:
why not use self.op_test.xxx directly?

@Charles-hit Charles-hit deleted the prim_test_frame branch February 28, 2023 07:12