
[Phi]Move kron kernel to phi #40427

Merged
merged 14 commits into from
Mar 15, 2022

Conversation

ZzSean (Contributor) commented Mar 10, 2022

PR types

Others

PR changes

Others

Describe

[Phi]Move kron kernel to phi

@paddle-bot-old:

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

// limitations under the License.

#include "paddle/phi/kernels/impl/kron_grad_kernel_impl.h"
#include "paddle/phi/kernels/kron_grad_kernel.h"
Contributor: Move this header include to the first line.

Contributor Author: done, thx

// limitations under the License.

#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
#include "paddle/phi/kernels/kron_kernel.h"
Contributor: Same as above.

namespace phi {

namespace ops = paddle::operators;
namespace plat = paddle::platform;
Contributor: Could the complex type under the phi::dtype namespace be used here?

Contributor Author: done, thx

Comment on lines 21 to 22
namespace ops = paddle::operators;
namespace plat = paddle::platform;
Contributor: Under phi, it is better not to use aliases for paddle::xxx namespaces; they will make later replacement harder.

Contributor Author: done, thx

const plat::complex<T>* dout_;
const plat::complex<T>* A_;
const plat::complex<T>* B_;
plat::complex<T>* dout_a_;
Contributor: plat -> phi::dtype

Contributor Author: done, thx

p_dout_y = dout_y.data<T>();
}

plat::ForRange<Context> for_range(dev_ctx, numel);
Contributor: Use the ForRange under phi.

Contributor Author: done, thx

Comment on lines 239 to 244
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
}
if (dy) {
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1}, stream);
Contributor: TensorReduceImpl can use the version under phi.

Contributor Author: done, thx

Comment on lines 20 to 21
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
Contributor: op_registry.h should no longer be needed here; use the for_range.h under phi.

#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
Contributor: reduce_op.cu.h can be replaced with paddle/phi/kernels/funcs/reduce_function.h.

Contributor Author: done, thx

Comment on lines 19 to 21
KernelSignature KronOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("kron", {"X", "Y"}, {}, {"Out"});
}
Contributor: The forward ArgumentMapping here does not appear to have any special case, so it may be unnecessary; try whether the default argument mapping works directly.

auto stream = dev_ctx.stream(); // it is a cuda device_context
auto* ctx = reinterpret_cast<const plat::CUDADeviceContext*>(&dev_ctx);
if (dx) {
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
Contributor: ops::TensorReduceImpl has been migrated; phi::funcs::ReduceKernel can be used here.

*ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
}
if (dy) {
ops::TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
Contributor: Same as above.

#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
Contributor: The original op registration header is not needed under the phi directory.

Contributor Author: done, thx

#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
Contributor: Include paddle/phi/kernels/funcs/for_range.h instead.

Contributor Author: done, thx

#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
Contributor: ops::TensorReduceImpl has been migrated; phi::funcs::ReduceKernel can be used here, and this header is no longer needed.

Contributor Author: done, thx

p_shape_y = dim_y.Get();
#endif

paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
Contributor: Use the ForRange under phi instead.

Contributor Author: done, thx

@MingMingShangTian (Contributor) left a comment: LGTM

@ZzSean ZzSean merged commit f181d47 into PaddlePaddle:develop Mar 15, 2022
@ZzSean ZzSean deleted the move_kron branch April 14, 2022 09:00
4 participants