Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] add where, atan2, median 0-Dim ut #49692

Merged
merged 5 commits into from
Jan 13, 2023

Conversation

ronny1996
Copy link
Contributor

@ronny1996 ronny1996 commented Jan 10, 2023

PR types

Others

PR changes

Others

Describe

add where, atan2, median 0-Dim ut

@paddle-bot
Copy link

paddle-bot bot commented Jan 10, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@ronny1996 ronny1996 changed the title add where, atan2, median 0d ut add where, atan2, median 0-Dim ut Jan 11, 2023
@ronny1996 ronny1996 changed the title add where, atan2, median 0-Dim ut [Zero-Dim] add where, atan2, median 0-Dim ut Jan 12, 2023
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 1)

def test_atan2(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个api是二元支持广播的吗?支持的话可以直接放到binary_api_list里

Copy link
Contributor Author

@ronny1996 ronny1996 Jan 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

atan和where不支持广播

x1.stop_gradient = False
x2.stop_gradient = False
out = paddle.where(x1 > x2, x1, x2)
paddle.static.append_backward(out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle.static.append_backward(out.sum()) 然后测下前反向的值和shape吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已更新

self.assertEqual(res[0].shape, ())

@prog_scope()
def test_atan2(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同广播问题

@@ -43,6 +43,10 @@ void WhereKernel(const Context& ctx,

int ret = xpu::select(
ctx.x_context(), cond_data, x_data, y_data, out_data, cond_dims, x_dims);

if (cond_dims.size() == 0 && x_dims.size() == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个按道理kernel内部不会单独去改变shape,如果infermeta是对的,这个还需要改不

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

只在两个输入都是0维时需要,前面为了计算0维,做了特殊处理

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ronny1996 ronny1996 merged commit 1508cae into PaddlePaddle:develop Jan 13, 2023
@ronny1996 ronny1996 deleted the 0dut branch January 13, 2023 09:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants