[Zero-Dim] Support 0D Tensor input for topk/broadcast_to/expand/expand_as/broadcast_shape #50536

Merged · 25 commits · Feb 24, 2023
39 changes: 25 additions & 14 deletions paddle/phi/infermeta/unary.cc
@@ -472,6 +472,7 @@ void CumInferMeta(const MetaTensor& x,
out->set_dims(x_dims);
out->set_dtype(x.dtype());
}

out->share_lod(x);
}

@@ -970,7 +971,7 @@ void ExpandInferMeta(const MetaTensor& x,
MAX_RANK_SUPPORTED));
PADDLE_ENFORCE_GE(
expand_shape.size(),
1,
0,
phi::errors::InvalidArgument("The number of elements (%d) of 'shape' for "
"must be a positive integer.",
expand_shape.size()));
@@ -1005,7 +1006,7 @@ void ExpandInferMeta(const MetaTensor& x,

out->set_dims(make_ddim(out_shape));
out->set_dtype(x.dtype());
if (out_shape[0] == x_dims[0]) {
if (out_rank > 0 && out_shape[0] == x_dims[0]) {
out->share_lod(x);
}
}
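With the minimum `shape` size relaxed from 1 to 0, an empty shape list becomes a legal no-op expand of a 0-D tensor, and the new `out_rank > 0` guard keeps the lod-sharing check from reading `out_shape[0]` on a rank-0 result. A minimal standalone sketch of the inference rule, using hypothetical names and simplified semantics rather than Paddle's actual code:

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical helper (not Paddle's code) mirroring the relaxed rule in
// ExpandInferMeta: an empty 'shape' is legal and yields a 0-D output.
std::vector<int> InferExpandShape(const std::vector<int>& x_dims,
                                  const std::vector<int>& expand_shape) {
  assert(expand_shape.size() >= x_dims.size());
  const size_t diff = expand_shape.size() - x_dims.size();
  std::vector<int> out(expand_shape.size());
  for (size_t i = 0; i < expand_shape.size(); ++i) {
    // -1 keeps the corresponding input dimension (only valid past 'diff').
    out[i] = (expand_shape[i] == -1) ? x_dims[i - diff] : expand_shape[i];
  }
  return out;  // empty 'expand_shape' -> empty dims, i.e. a 0-D tensor
}

int main() {
  assert(InferExpandShape({}, {}).empty());  // 0-D stays 0-D
  assert((InferExpandShape({3}, {2, -1}) == std::vector<int>{2, 3}));
  std::printf("ok\n");
}
```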
@@ -4097,14 +4098,23 @@ void TopKInferMeta(const MetaTensor& x,
MetaConfig config) {
auto input_dims = x.dims();
const int& dim_size = input_dims.size();
PADDLE_ENFORCE_EQ(
(axis < dim_size) && (axis >= (-1 * dim_size)),
true,
phi::errors::InvalidArgument(
"the axis of topk must be [-%d, %d), but you set axis is %d",
dim_size,
dim_size,
axis));
if (dim_size != 0) {
PADDLE_ENFORCE_EQ(
(axis < dim_size) && (axis >= (-1 * dim_size)),
true,
phi::errors::InvalidArgument(
"the axis of topk must be [-%d, %d), but you set axis is %d",
dim_size,
dim_size,
axis));
} else {
PADDLE_ENFORCE_EQ(
(axis == dim_size) || (axis == -1),
true,
phi::errors::InvalidArgument("the axis of topk must be 0 or -1 when "
"x.dims() = 0, but you set axis is %d",
axis));
}

if (axis < 0) axis += dim_size;

@@ -4122,12 +4132,13 @@ void TopKInferMeta(const MetaTensor& x,

PADDLE_ENFORCE_GE(
input_dims.size(),
1,
phi::errors::InvalidArgument("input of topk must have >= 1d shape"));
0,
phi::errors::InvalidArgument("input of topk must have >= 0d shape"));

phi::DDim dims = input_dims;

dims[axis] = k;
if (input_dims.size() > 0) {
dims[axis] = k;
}
out->set_dims(dims);
out->share_lod(x);
out->set_dtype(x.dtype());
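The rank-0 branch follows the usual axis-normalization convention: the only axes that make sense for a 0-D input are 0 and -1. A hedged restatement of the check, with illustrative names:

```cpp
#include <stdexcept>

// Illustrative restatement (not phi's API) of the axis validation above.
int NormalizeTopKAxis(int dim_size, int axis) {
  if (dim_size != 0) {
    if (axis < -dim_size || axis >= dim_size)
      throw std::invalid_argument("axis of topk out of range");
  } else if (axis != 0 && axis != -1) {
    throw std::invalid_argument("axis of topk must be 0 or -1 for 0-D input");
  }
  // For a 0-D input, axis == -1 stays -1 here, but the caller never uses it
  // to index dims (mirroring the `input_dims.size() > 0` guard above).
  return axis < 0 ? axis + dim_size : axis;
}
```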
5 changes: 5 additions & 0 deletions paddle/phi/kernels/cpu/top_k_grad_kernel.cc
@@ -66,6 +66,11 @@ void TopkGradKernel(const Context& dev_ctx,
axis = (axis < 0) ? (in_dims.size() + axis) : axis;

T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
if (in_dims.size() == 0) {
phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
return;
}

if (axis + 1 == in_dims.size()) {
// allocate the memory for the input_grad

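For a 0-D input the forward map is the identity, `k` can only be 1 and there is nothing to select, so the backward pass in the kernel above (and in its GPU counterpart further down) reduces to copying the upstream gradient:

$$ y = \operatorname{topk}(x, k{=}1) = x \quad\Longrightarrow\quad \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} $$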
8 changes: 7 additions & 1 deletion paddle/phi/kernels/cpu/top_k_kernel.cc
@@ -140,7 +140,13 @@ void TopkKernel(const Context& dev_ctx,
const auto* input = &x;
// Get the top k elements of each row of input tensor
const auto& in_dims = input->dims();

// 0d input x
if (in_dims.size() == 0) {
phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
indices_data[0] = 0;
return;
}
// axis < 0, calculate the real axis
if (axis < 0) {
axis += in_dims.size();
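The CPU kernel above and the GPU kernel later in the diff share the same 0-D contract: the value passes through unchanged and the single index is 0. A standalone C++ sketch of that contract (an illustration, not the phi API):

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>

// Scalar-only mirror of topk's 0-D behavior: the output value equals the
// input and the index is always 0.
std::pair<float, int64_t> TopK0D(float x) { return {x, 0}; }

int main() {
  auto [value, index] = TopK0D(3.5f);
  std::printf("value=%.1f index=%lld\n", value,
              static_cast<long long>(index));
}
```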
5 changes: 5 additions & 0 deletions paddle/phi/kernels/gpu/top_k_grad_kernel.cu
@@ -47,6 +47,11 @@ void TopkGradKernel(const Context& dev_ctx,
const T* out_grad_data = out_grad.data<T>();
const int64_t* indices_data = indices.data<int64_t>();

if (in_dims.size() == 0) {
phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
return;
}

int pre, n, post;
phi::funcs::GetDims(in_dims, axis, &pre, &n, &post);

8 changes: 8 additions & 0 deletions paddle/phi/kernels/gpu/top_k_kernel.cu
@@ -61,6 +61,14 @@ void TopkKernel(const Context& dev_ctx,
const auto* input = &x;
// get the input dims
const auto& in_dims = input->dims();

// 0d input tensor
if (in_dims.size() == 0) {
phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
    dev_ctx.template Alloc<int64_t>(indices);
    // A host-side store through the device pointer (the original
    // `indices_data = 0;` also nulled the pointer instead of writing the
    // element) is invalid here; fill the single index on the device.
    phi::funcs::set_constant(dev_ctx, indices, 0);
return;
}
// calculate the real axis
if (axis < 0) axis += in_dims.size();

112 changes: 55 additions & 57 deletions paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h
@@ -49,6 +49,12 @@ void ExpandAsGradKernel(const Context& context,
const std::vector<int>& target_shape,
DenseTensor* in_grad) {
auto x_dims = x.dims();

if (in_grad->dims() == out_grad.dims()) {
phi::Copy(context, out_grad, context.GetPlace(), false, in_grad);
return;
}

auto vec_in_dims = phi::vectorize<int>(x_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
@@ -65,64 +71,56 @@
}

int dims = reduce_dims_vec.size();
bool just_copy = true;
for (size_t i = 0; i < repeat_times.size(); i++) {
if (repeat_times[i] != 1) {
just_copy = false;

PADDLE_ENFORCE_GE(
dims,
0,
errors::InvalidArgument("The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be greater than or "
"equal to 0, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(
dims,
MAX_RANK_SUPPORTED,
errors::InvalidArgument("The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
dims));
switch (dims) {
case 0:
ExpandAsBackward<Context, T, 0>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
}
}
// no need reduce, just copy
if (just_copy) {
context.template Alloc<T>(in_grad);
phi::Copy(context, out_grad, context.GetPlace(), false, in_grad);
} else {
PADDLE_ENFORCE_GE(
dims,
1,
errors::InvalidArgument("The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be greater than or "
"equal to 1, but the value received is %d.",
dims));
PADDLE_ENFORCE_LE(dims,
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The rank of the input 'Out@GRAD' for "
"expand_as_v2_grad op must be less than or equal "
"to %d, but the value received is %d.",
MAX_RANK_SUPPORTED,
dims));
switch (dims) {
case 1:
ExpandAsBackward<Context, T, 1>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 2:
ExpandAsBackward<Context, T, 2>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 3:
ExpandAsBackward<Context, T, 3>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 4:
ExpandAsBackward<Context, T, 4>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 5:
ExpandAsBackward<Context, T, 5>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 6:
ExpandAsBackward<Context, T, 6>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
default:
PADDLE_THROW(errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"received tensor's rank = %d.",
dims));
}
case 1:
ExpandAsBackward<Context, T, 1>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 2:
ExpandAsBackward<Context, T, 2>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 3:
ExpandAsBackward<Context, T, 3>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 4:
ExpandAsBackward<Context, T, 4>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 5:
ExpandAsBackward<Context, T, 5>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
case 6:
ExpandAsBackward<Context, T, 6>(
context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad);
break;
default:
PADDLE_THROW(errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"received tensor's rank = %d.",
dims));
}
}

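The old `just_copy` scan over `repeat_times` is gone, replaced by a direct comparison of `in_grad` and `out_grad` dims at the top of the function, and rank 0 becomes an ordinary `case 0` in the template dispatch. The dispatch idiom, a runtime switch selecting a compile-time rank, can be sketched independently of phi (names here are illustrative):

```cpp
#include <cstdio>
#include <stdexcept>

// Toy rank dispatch: the backward functor is templated on rank so the tensor
// library sees a compile-time rank, and a runtime switch selects the
// instantiation; rank 0 is now simply another case.
template <int Rank>
void ExpandAsBackwardSketch() {
  std::printf("reduce gradient at static rank %d\n", Rank);
}

void DispatchExpandAsBackward(int dims) {
  switch (dims) {
    case 0: ExpandAsBackwardSketch<0>(); break;
    case 1: ExpandAsBackwardSketch<1>(); break;
    case 2: ExpandAsBackwardSketch<2>(); break;
    // ... up to the maximum supported rank (6 in the real kernel)
    default: throw std::invalid_argument("unsupported rank");
  }
}
```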
9 changes: 8 additions & 1 deletion paddle/phi/kernels/impl/expand_as_kernel_impl.h
@@ -34,6 +34,10 @@ void ExpandAs(const Context& context,
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
if (Rank == 0) {
phi::Copy<Context>(context, x, context.GetPlace(), false, out);
return;
}
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(
target_shape[i],
@@ -108,7 +112,7 @@ void ExpandAsKernel(const Context& ctx,
rank));
PADDLE_ENFORCE_GE(
rank,
1,
0,
errors::InvalidArgument("The rank (%d) of the input 'x' for "
"expand_as_v2 op must be positive.",
rank));
@@ -133,6 +137,9 @@
}

switch (target_rank) {
case 0:
ExpandAs<Context, T, 0>(ctx, x, real_target_shape, out);
break;
case 1:
ExpandAs<Context, T, 1>(ctx, x, real_target_shape, out);
break;
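A 0-D target can only match a 0-D input, so the `Rank == 0` instantiation exits with a plain copy before the `repeat_times` loop runs. A minimal sketch of that fast path, under the assumption of a single-element float buffer (not the phi signature):

```cpp
#include <cstring>

// Sketch of the Rank == 0 fast path: with empty dims there are no
// repeat_times and nothing to broadcast, so the kernel degenerates to
// copying the single element (phi::Copy in the real code).
void ExpandAs0D(const float* x, float* out) {
  std::memcpy(out, x, sizeof(float));  // one element, same layout
}
```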