-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 2】9、为 Paddle 新增 logspace API #41261
Changes from all commits
4007615
abc377f
18f7ac3
14823ed
7e33b85
7fb1d2e
fa4bac2
efe5577
30d067f
7b9c8ed
3e2e62a
fee242f
05ffd43
f0862ad
ec204b6
15ef372
88a7446
7cc079f
3530acf
407c2a6
8d04f59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <string> | ||
|
||
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
#include "paddle/phi/core/infermeta_utils.h" | ||
#include "paddle/phi/infermeta/multiary.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class LogspaceOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
return framework::OpKernelType( | ||
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), | ||
ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class LogspaceOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("Start", | ||
"Exponent of first entry in the sequence. It is a tensor of " | ||
"shape [1], should be of type int32, int64, float32 or float64."); | ||
AddInput("Stop", | ||
"Exponent of last entry in the sequence. It is a tensor of " | ||
"shape [1], should be of type int32, int64, float32 or float64."); | ||
AddInput("Num", | ||
"Number of entry in the sequence. It is a tensor of shape [1], " | ||
"should be of type int32."); | ||
AddInput("Base", | ||
"Base of the logarithm function. It is a tensor of shape [1], " | ||
"should be of type int32, int64, float32 or float64."); | ||
AddAttr<int>("dtype", "The output data type."); | ||
AddOutput("Out", "A sequence of numbers."); | ||
AddComment(R"DOC( | ||
Return fixed number of logarithmical-evenly spaced values within a given | ||
interval. First entry is exponential of Start with base Base, and last | ||
entry is exponential of Stop with base Base. In the case when Num is 1, | ||
only exponential of Start with base Base is returned. If dtype is int32 | ||
or int64, the decimal part of values will be truncated. | ||
Like logspace function of numpy. | ||
)DOC"); | ||
} | ||
}; | ||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
DECLARE_INFER_SHAPE_FUNCTOR(logspace, LogspaceInferShapeFunctor, | ||
PD_INFER_META(phi::LogspaceInferMeta)); | ||
REGISTER_OPERATOR( | ||
logspace, ops::LogspaceOp, ops::LogspaceOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, | ||
LogspaceInferShapeFunctor); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/logspace_kernel.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. C++基础头文件和项目自身的头文件之间空一行,方便区分。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 完成 |
||
|
||
#include <cmath> | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/data_type_transform.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void LogspaceKernel(const Context& ctx, | ||
const DenseTensor& start, | ||
const DenseTensor& stop, | ||
const DenseTensor& number, | ||
const DenseTensor& base, | ||
DataType dtype, | ||
DenseTensor* out) { | ||
int32_t num = number.data<int32_t>()[0]; | ||
auto start_t = phi::funcs::TransDataType(ctx, start, dtype); | ||
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); | ||
auto base_t = phi::funcs::TransDataType(ctx, base, dtype); | ||
|
||
T start_data = start_t.template data<T>()[0]; | ||
T stop_data = stop_t.template data<T>()[0]; | ||
T base_data = base_t.template data<T>()[0]; | ||
PADDLE_ENFORCE_GT( | ||
num, | ||
0, | ||
phi::errors::InvalidArgument("The num of logspace op should be larger " | ||
"than 0, but received num is %d", | ||
num)); | ||
|
||
out->Resize(phi::make_ddim({num})); | ||
T* out_data = ctx.template Alloc<T>(out); | ||
|
||
if (num > 1) { | ||
// step should be of double type for all types | ||
double step = (static_cast<double>(stop_data - start_data)) / (num - 1); | ||
int half_num = num / 2; | ||
for (int i = 0; i < num; ++i) { | ||
if (i < half_num) { | ||
out_data[i] = | ||
static_cast<T>(std::pow(base_data, start_data + step * i)); | ||
} else { | ||
out_data[i] = static_cast<T>( | ||
std::pow(base_data, stop_data - step * (num - i - 1))); | ||
} | ||
} | ||
} else { | ||
out_data[0] = static_cast<T>(std::pow(base_data, start_data)); | ||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(logspace, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::LogspaceKernel, | ||
float, | ||
int32_t, | ||
int64_t, | ||
double) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/logspace_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/copy_kernel.h" | ||
#include "paddle/phi/kernels/funcs/data_type_transform.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T> | ||
__global__ void LogspaceKernelInner( | ||
T start, T stop, double step, T base, int64_t size, T* out) { | ||
int64_t index = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
for (; index < size; index += blockDim.x * gridDim.x) { | ||
if (index < size / 2) { | ||
out[index] = | ||
static_cast<T>(pow(static_cast<double>(base), | ||
static_cast<double>(start + step * index))); | ||
} else { | ||
out[index] = static_cast<T>( | ||
pow(static_cast<double>(base), | ||
static_cast<double>(stop - step * (size - index - 1)))); | ||
} | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void LogspaceSpecialKernel(T start, T base, T* out) { | ||
out[0] = static_cast<T>( | ||
pow(static_cast<double>(base), static_cast<double>(start))); | ||
} | ||
|
||
template <typename T, typename Context> | ||
void LogspaceKernel(const Context& ctx, | ||
const DenseTensor& start, | ||
const DenseTensor& stop, | ||
const DenseTensor& number, | ||
const DenseTensor& base, | ||
DataType dtype, | ||
DenseTensor* out) { | ||
auto start_t = phi::funcs::TransDataType(ctx, start, dtype); | ||
auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); | ||
auto base_t = phi::funcs::TransDataType(ctx, base, dtype); | ||
|
||
DenseTensor n_start; | ||
DenseTensor n_stop; | ||
DenseTensor n_num; | ||
DenseTensor n_base; | ||
phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start); | ||
T start_data = n_start.data<T>()[0]; | ||
phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop); | ||
T stop_data = n_stop.data<T>()[0]; | ||
phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num); | ||
int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]); | ||
phi::Copy(ctx, base_t, phi::CPUPlace(), false, &n_base); | ||
T base_data = n_base.data<T>()[0]; | ||
|
||
PADDLE_ENFORCE_GT( | ||
num, | ||
0, | ||
phi::errors::InvalidArgument("The num of logspace op should be larger " | ||
"than 0, but received num is %d", | ||
num)); | ||
|
||
out->Resize(phi::make_ddim({num})); | ||
T* out_data = ctx.template Alloc<T>(out); | ||
|
||
double step = 0; | ||
auto stream = ctx.stream(); | ||
int block = 512; | ||
int grid = (num + block - 1) / block; | ||
if (num != 1) { | ||
step = (static_cast<double>(stop_data - start_data)) / (num - 1); | ||
LogspaceKernelInner<T><<<grid, block, 0, stream>>>( | ||
start_data, stop_data, step, base_data, num, out_data); | ||
} else { | ||
LogspaceSpecialKernel<T><<<grid, block, 0, stream>>>( | ||
start_data, base_data, out_data); | ||
} | ||
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL(logspace, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::LogspaceKernel, | ||
float, | ||
int32_t, | ||
int64_t, | ||
double) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的license格式好像还是有点问题,可以后续再完善下 |
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/phi/core/dense_tensor.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename Context> | ||
void LogspaceKernel(const Context& ctx, | ||
const DenseTensor& start, | ||
const DenseTensor& stop, | ||
const DenseTensor& number, | ||
const DenseTensor& base, | ||
DataType dtype, | ||
DenseTensor* out); | ||
|
||
} // namespace phi |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C++基础头文件和项目自身的头文件之间空一行,方便区分。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
完成