Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 4th No.24】为 Paddle 新增 paddle.sparse.is_nan 稀疏 API #51513

Merged
merged 31 commits into from
Apr 3, 2023

Conversation

thunder95
Copy link
Contributor

PR types

New features

PR changes

APIs

Describe

完成第四期第24项目开发任务: https://github.com/PaddlePaddle/community/blob/master/hackthon_4th/%E3%80%90PaddlePaddle%20Hackathon%204%E3%80%91%20%E6%A0%B8%E5%BF%83%E6%A1%86%E6%9E%B6%E5%BC%80%E6%BA%90%E8%B4%A1%E7%8C%AE%20API%20%E5%BC%80%E5%8F%91%E4%BB%BB%E5%8A%A1%E5%90%88%E9%9B%86.md#task24
isnan 检查输入Tensor 的每一个值是否为 +/-NaN, 并返回布尔型结果。目前在 PaddlePaddle 中,对于稀疏Tensor还没有支持isnan的API。

RFC设计文档: PaddlePaddle/community#415
中文api文档:PaddlePaddle/docs#5705

@paddle-bot
Copy link

paddle-bot bot commented Mar 11, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -170,6 +171,20 @@ SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) {
return csr;
}

template <typename T, typename Context>
SparseCooTensor IsnanCoo(const Context& dev_ctx, const SparseCooTensor& x) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个额外地C++ API可以删掉,一般用kernel表示C++ api就可以

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你这个删掉没

@zhwesky2010
Copy link
Contributor

没有支持静态图,可能需要 @zkh2016 帮忙看下静态图的应该怎么写

@zkh2016
Copy link
Contributor

zkh2016 commented Mar 15, 2023

没有支持静态图,可能需要 @zkh2016 帮忙看下静态图的应该怎么写

coo格式的静态图单测还是建议加下。test_sparse_norm_op.py和test_sparse_conv_op.py有测试例子,可以先用这种方式测下,验下正确性。

@thunder95
Copy link
Contributor Author

test_sparse_norm_op

@zhouwei25 @zkh2016 这个算子有点特殊, 计算结果是bool型, 在to_dense或return_numpy=True的时候,都会报错。辛苦两位老师给个建议。

InvalidArgumentError: The type of data we are trying to retrieve (float32) does not match the type of data (bool) currently contained in the container.
1220: [Hint: Expected dtype() == paddle::experimental::CppTypeToDataType::Type(), but received dtype():1 != paddle::experimental::CppTypeToDataType::Type():10.] (at /paddle/paddle/phi/core/dense_tensor.cc:163)

@zkh2016
Copy link
Contributor

zkh2016 commented Mar 16, 2023

test_sparse_norm_op

@zhouwei25 @zkh2016 这个算子有点特殊, 计算结果是bool型, 在to_dense或return_numpy=True的时候,都会报错。辛苦两位老师给个建议。

InvalidArgumentError: The type of data we are trying to retrieve (float32) does not match the type of data (bool) currently contained in the container. 1220: [Hint: Expected dtype() == paddle::experimental::CppTypeToDataType::Type(), but received dtype():1 != paddle::experimental::CppTypeToDataType::Type():10.] (at /paddle/paddle/phi/core/dense_tensor.cc:163)

当前to_dense还没注册bool类型,你可以先注册一个,测测看。

@thunder95
Copy link
Contributor Author

thunder95 commented Mar 19, 2023

@zhouwei25 @zkh2016
已尝试注册bool类型,依旧是同样的报错。

有个不明白的地方,虽然我使用的是IsfiniteInferMeta, 返回的稀疏张量的dtype应该是bool才对,但是打印出来仍旧是float类型,所以定位在下面的内存或显存分配会发生报错,因为x的dtype是bool类型,但是模板T是float:

template <typename T, typename Context> DenseTensor CooToDense(const Context& dev_ctx, const SparseCooTensor& x) const T* x_data = values.data<T>(); ===> 出现类型转换错误
打印稀疏张量的时候,dtype并不是bool类型,所以注册bool类型没有产生效果

Tensor(shape=[2, 2, 2], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, indices=[[0, 0, 1, 1], [1, 1, 1, 1], [0, 1, 0, 1]], values=[False, False, False, True ])

@zkh2016
Copy link
Contributor

zkh2016 commented Mar 20, 2023

@zhouwei25 @zkh2016 已尝试注册bool类型,依旧是同样的报错。

有个不明白的地方,虽然我使用的是IsfiniteInferMeta, 返回的稀疏张量的dtype应该是bool才对,但是打印出来仍旧是float类型,所以定位在下面的内存或显存分配会发生报错,因为x的dtype是bool类型,但是模板T是float:

template <typename T, typename Context> DenseTensor CooToDense(const Context& dev_ctx, const SparseCooTensor& x) const T* x_data = values.data<T>(); ===> 出现类型转换错误 打印稀疏张量的时候,dtype并不是bool类型,所以注册bool类型没有产生效果

Tensor(shape=[2, 2, 2], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, indices=[[0, 0, 1, 1], [1, 1, 1, 1], [0, 1, 0, 1]], values=[False, False, False, True ])

可能需要辛苦你进一步定位下,我看了下当前是通过DEFINE_SPARSE_UNARY_KERNEL这个宏定义is_nan的kernel的,这里面重新创建了out,所以类型可能变了。

@thunder95
Copy link
Contributor Author

@zhouwei25 @zkh2016 辛苦两位老师再review一下

args : (Tensor x)
output : Tensor(out)
infer_meta :
func : IsfiniteInferMeta
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个名字不太对?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsfiniteInferMeta吗?最新提交已修改成unchanged

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhouwei25 如果改成unchanged,还是会报错,最新提交又修改回IsfiniteInferMeta
InvalidArgumentError: The type of data we are trying to retrieve (float32) does not match the type of data (bool) currently contained in the container.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhouwei25 如果改成unchanged,还是会报错,最新提交又修改回IsfiniteInferMeta InvalidArgumentError: The type of data we are trying to retrieve (float32) does not match the type of data (bool) currently contained in the container.

就是命名风格有点问题,这个不是IsNanInferMeta吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhouwei25 IsfiniteInferMeta是参考dense tensor那里设计的IsfiniteInferMeta, 考虑到代码可能会冗余就这么直接复用了,老师建议这个地方是需要单独写一个IsNanInferMeta吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PD_REGISTER_INFER_META_FN(isnan, phi::IsfiniteInferMeta);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好复用也可以

@@ -170,6 +171,20 @@ SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) {
return csr;
}

template <typename T, typename Context>
SparseCooTensor IsnanCoo(const Context& dev_ctx, const SparseCooTensor& x) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你这个删掉没

}

template <typename T, typename Context>
SparseCooTensor IsnanCsr(const Context& dev_ctx, const SparseCooTensor& x) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个删掉,需要kernel就可以

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删掉

@@ -219,5 +220,44 @@ void CastCsrKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void IsnanCooKernel(const Context& dev_ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个用目前的公共组件,宏函数来注册kernel。可以复用减少代码

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhouwei25 公共组件不能满足这个算子,因为emptylike会创建一个相同类型的输出tensor,而这个算子输出是bool型的,所以这里单独写了个kernel。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhouwei25 公共组件不能满足这个算子,因为emptylike会创建一个相同类型的输出tensor,而这个算子输出是bool型的,所以这里单独写了个kernel。

OK

zhwesky2010
zhwesky2010 previously approved these changes Mar 29, 2023
Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

jeff41404
jeff41404 previously approved these changes Mar 29, 2023
Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Return whether every element of input tensor is `NaN` or not, requiring x to be a SparseCooTensor or SparseCsrTensor.

Args:
x (Tensor): The input tensor, it's data type should be float16, float32, float64, int32, int64.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

中文文档中 '可以为 Coo 或 Csr 格式' 可以表达出来

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sunzhongkai588 已修改

Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@ZzSean ZzSean left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for CI-OP-Benchmark

@luotao1 luotao1 closed this Apr 3, 2023
@luotao1 luotao1 reopened this Apr 3, 2023
@luotao1 luotao1 merged commit b7db6af into PaddlePaddle:develop Apr 3, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants