-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Zero-Dim] support input 0D Tensor for sundary api #47734
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
|
||
import paddle | ||
import paddle.fluid as fluid | ||
import paddle.nn.functional as F | ||
import numpy as np | ||
import unittest | ||
|
||
|
@@ -67,7 +68,7 @@ | |
] | ||
|
||
|
||
# Use to test zero-dim in the whole API | ||
# Use to test zero-dim in unary API. | ||
class TestUnaryAPI(unittest.TestCase): | ||
def test_dygraph_unary(self): | ||
paddle.disable_static() | ||
|
@@ -176,6 +177,7 @@ def test_static_unary(self): | |
] | ||
|
||
|
||
# Use to test zero-dim of reduce API | ||
class TestReduceAPI(unittest.TestCase): | ||
def test_dygraph(self): | ||
paddle.disable_static() | ||
|
@@ -232,31 +234,32 @@ def test_static(self): | |
{'func': paddle.multiply, 'cls_method': '__mul__'}, | ||
{'func': paddle.divide, 'cls_method': '__div__'}, | ||
{'func': paddle.subtract, 'cls_method': '__sub__'}, | ||
paddle.pow, | ||
] | ||
|
||
binary_api_list_without_grad = [ | ||
{'func': paddle.pow, 'cls_method': '__pow__'}, | ||
{'func': paddle.add, 'cls_method': '__add__'}, | ||
{'func': paddle.subtract, 'cls_method': '__sub__'}, | ||
{'func': paddle.multiply, 'cls_method': '__mul__'}, | ||
{'func': paddle.divide, 'cls_method': '__div__'}, | ||
{'func': paddle.subtract, 'cls_method': '__sub__'}, | ||
paddle.pow, | ||
{'func': paddle.mod, 'cls_method': '__mod__'}, | ||
paddle.floor_mod, | ||
paddle.remainder, | ||
{'func': paddle.pow, 'cls_method': '__pow__'}, | ||
] | ||
|
||
binary_api_list_without_grad = [ | ||
{'func': paddle.equal, 'cls_method': '__eq__'}, | ||
{'func': paddle.not_equal, 'cls_method': '__ne__'}, | ||
{'func': paddle.greater_equal, 'cls_method': '__ge__'}, | ||
{'func': paddle.greater_than, 'cls_method': '__gt__'}, | ||
{'func': paddle.less_equal, 'cls_method': '__le__'}, | ||
{'func': paddle.less_than, 'cls_method': '__lt__'}, | ||
{'func': paddle.remainder, 'cls_method': '__mod__'}, | ||
paddle.mod, | ||
paddle.floor_mod, | ||
paddle.logical_and, | ||
paddle.logical_or, | ||
paddle.logical_xor, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已添加 |
||
] | ||
|
||
|
||
# Use to test zero-dim of binary API | ||
class TestBinaryAPI(unittest.TestCase): | ||
def test_dygraph_binary(self): | ||
paddle.disable_static() | ||
|
@@ -274,8 +277,6 @@ def test_dygraph_binary(self): | |
else: | ||
out = api(x, y) | ||
|
||
self.assertEqual(x.shape, []) | ||
self.assertEqual(y.shape, []) | ||
self.assertEqual(out.shape, []) | ||
|
||
if api not in binary_api_list_without_grad: | ||
|
@@ -296,8 +297,6 @@ def test_dygraph_binary(self): | |
else: | ||
out = api(x, y) | ||
|
||
self.assertEqual(x.shape, [2, 3, 4]) | ||
self.assertEqual(y.shape, []) | ||
self.assertEqual(out.shape, [2, 3, 4]) | ||
|
||
if api not in binary_api_list_without_grad: | ||
|
@@ -317,10 +316,7 @@ def test_dygraph_binary(self): | |
np.testing.assert_array_equal(out_cls.numpy(), out.numpy()) | ||
else: | ||
out = api(x, y) | ||
out.backward() | ||
|
||
self.assertEqual(x.shape, []) | ||
self.assertEqual(y.shape, [2, 3, 4]) | ||
self.assertEqual(out.shape, [2, 3, 4]) | ||
|
||
if api not in binary_api_list_without_grad: | ||
|
@@ -329,19 +325,32 @@ def test_dygraph_binary(self): | |
self.assertEqual(y.grad.shape, [2, 3, 4]) | ||
self.assertEqual(out.grad.shape, [2, 3, 4]) | ||
|
||
# 4) x is 0D , y is scalar | ||
x = paddle.rand([]) | ||
y = 0.5 | ||
x.stop_gradient = False | ||
if isinstance(api, dict): | ||
out = getattr(paddle.Tensor, api['cls_method'])(x, y) | ||
self.assertEqual(out.shape, []) | ||
|
||
paddle.enable_static() | ||
|
||
def test_static_unary(self): | ||
paddle.enable_static() | ||
for api in binary_api_list: | ||
main_prog = fluid.Program() | ||
with fluid.program_guard(main_prog, fluid.Program()): | ||
# 1) x/y is 0D | ||
x = paddle.rand([]) | ||
y = paddle.rand([]) | ||
x.stop_gradient = False | ||
y.stop_gradient = False | ||
if isinstance(api, dict): | ||
out = api['func'](x, y) | ||
out_cls = getattr( | ||
paddle.static.Variable, api['cls_method'] | ||
)(x, y) | ||
self.assertEqual(out.shape, out_cls.shape) | ||
else: | ||
out = api(x, y) | ||
fluid.backward.append_backward(out) | ||
|
@@ -351,20 +360,112 @@ def test_static_unary(self): | |
block = prog.global_block() | ||
|
||
# Test compile shape | ||
self.assertEqual(x.shape, ()) | ||
self.assertEqual(y.shape, ()) | ||
self.assertEqual(out.shape, ()) | ||
|
||
exe = fluid.Executor() | ||
result = exe.run(main_prog, fetch_list=[x, y, out]) | ||
|
||
# Test runtime shape | ||
self.assertEqual(result[0].shape, ()) | ||
self.assertEqual(result[1].shape, ()) | ||
self.assertEqual(result[2].shape, ()) | ||
|
||
# 2) x is 0D , y is scalar | ||
x = paddle.rand([]) | ||
y = 0.5 | ||
x.stop_gradient = False | ||
if isinstance(api, dict): | ||
out = getattr(paddle.static.Variable, api['cls_method'])( | ||
x, y | ||
) | ||
self.assertEqual(out.shape, ()) | ||
|
||
paddle.disable_static() | ||
|
||
|
||
# Use to test zero-dim of Sundry API, which is simple and do | ||
# not have backward, or is not need to test backward in OpTest. | ||
class TestSundryAPI(unittest.TestCase): | ||
def setUp(self): | ||
self.x = paddle.rand([]) | ||
|
||
def test_linear(self): | ||
x = paddle.randn([3, 2]) | ||
w = paddle.full(shape=[2, 4], fill_value=0.5) | ||
b = paddle.zeros([]) | ||
|
||
np.testing.assert_array_equal( | ||
F.linear(x, w, b).numpy(), F.linear(x, w).numpy() | ||
) | ||
|
||
def test_is_complex(self): | ||
x = paddle.rand([]) + 1j * paddle.rand([]) | ||
self.assertTrue(paddle.is_complex(x)) | ||
|
||
def test_is_floating_point(self): | ||
self.assertTrue(paddle.is_floating_point(self.x)) | ||
|
||
def test_is_integer(self): | ||
x = paddle.randint(0, 10, []) | ||
self.assertTrue(paddle.is_integer(x)) | ||
|
||
def test_is_tensor(self): | ||
self.assertTrue(paddle.is_tensor(self.x)) | ||
|
||
def test_is_empty(self): | ||
x = paddle.rand([3, 0, 5]) | ||
self.assertTrue(paddle.is_empty(x)) | ||
|
||
def test_isfinite(self): | ||
out = paddle.isfinite(self.x) | ||
np.testing.assert_array_equal(out.numpy(), np.array(True)) | ||
|
||
def test_isinf(self): | ||
x = paddle.to_tensor(np.array(float('-inf'))) | ||
out = paddle.isinf(x) | ||
np.testing.assert_array_equal(out.numpy(), np.array(True)) | ||
|
||
def test_isnan(self): | ||
x = paddle.to_tensor(np.array(float('nan'))) | ||
out = paddle.isnan(x) | ||
np.testing.assert_array_equal(out.numpy(), np.array(True)) | ||
|
||
def test_isclose(self): | ||
out = paddle.isclose(self.x, self.x) | ||
np.testing.assert_array_equal(out.numpy(), np.array(True)) | ||
|
||
def test_clone(self): | ||
out = paddle.clone(self.x) | ||
np.testing.assert_array_equal(out.numpy(), self.x.numpy()) | ||
|
||
def test_assign(self): | ||
out = paddle.assign(self.x) | ||
np.testing.assert_array_equal(out.numpy(), self.x.numpy()) | ||
|
||
def test_item(self): | ||
x = paddle.full([], 0.5) | ||
self.assertEqual(x.item(), 0.5) | ||
|
||
def test_tolist(self): | ||
x = paddle.full([], 0.5) | ||
self.assertEqual(x.tolist(), 0.5) | ||
|
||
def test_numpy(self): | ||
x = paddle.full([], 0.5) | ||
np.testing.assert_array_equal(x.numpy(), np.array(0.5)) | ||
|
||
def test_numel(self): | ||
out = paddle.numel(self.x) | ||
self.assertEqual(out.shape, []) | ||
np.testing.assert_array_equal(out.numpy(), np.array(1)) | ||
|
||
def test_rank(self): | ||
out = paddle.rank(self.x) | ||
self.assertEqual(out.shape, []) | ||
np.testing.assert_array_equal(out.numpy(), np.array(0)) | ||
|
||
def test_shape(self): | ||
out = paddle.shape(self.x) | ||
self.assertEqual(out.shape, [0]) | ||
np.testing.assert_array_equal(out.numpy(), np.array([])) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why did
paddle.subtract
repeat 4 times, whilepaddle.add
2 times, what is the meaning of repeat in the list?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we add
paddle.maximum
andpaddle.minimum
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前只支持了部分二元,就是这个list中包含的。这两个还未支持0D,后面会加。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some mistake, have fixed it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, thanks