[Zero-Dim] Fix 0-dim tensor for arg_min_max op. #49570

Merged: 15 commits, Feb 1, 2023
47 changes: 32 additions & 15 deletions paddle/phi/infermeta/unary.cc
@@ -160,22 +160,34 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
   auto int_axis = axis.to<int64_t>();
   const auto& x_dims = x.dims();
 
-  PADDLE_ENFORCE_GE(
-      int_axis,
-      -x_dims.size(),
-      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
-                                   " -Rank(X)(%d).",
-                                   int_axis,
-                                   -x_dims.size()));
-  PADDLE_ENFORCE_LT(int_axis,
-                    x_dims.size(),
-                    phi::errors::InvalidArgument(
-                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
-                        int_axis,
-                        x_dims.size()));
+  auto x_rank = x.dims().size();
+  if (x_rank > 0) {
+    PADDLE_ENFORCE_GE(int_axis,
+                      -x_rank,
+                      phi::errors::InvalidArgument(
+                          "'axis'(%d) must be greater than or equal to"
+                          " -Rank(X)(%d).",
+                          int_axis,
+                          -x_rank));
+    PADDLE_ENFORCE_LT(
+        int_axis,
+        x_rank,
+        phi::errors::InvalidArgument(
+            "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+            int_axis,
+            x_rank));
+  } else {
+    // 0-dim tensor
+    PADDLE_ENFORCE_EQ((int_axis == 0 || int_axis == -1) && flatten,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "'axis'(%d) must be 0 or -1 if input tensor is "
+                          "0-dim. and flatten should be true.",
+                          int_axis));
+  }
 
-  auto x_rank = x_dims.size();
   if (int_axis < 0) int_axis += x_rank;
 
   if (config.is_runtime) {
     if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
       int64_t all_element_num = 0;
@@ -195,8 +207,12 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
                           INT_MAX));
     }
   }
+
   std::vector<int64_t> vec;
-  if (flatten) {
+
+  if (x_rank == 0) {
+    // vec is set to empty
+  } else if (flatten) {
     vec.emplace_back(static_cast<int64_t>(1));
   } else {
     for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
@@ -205,6 +221,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
     }
     for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
   }
+
   out->set_dims(phi::make_ddim(vec));
   if (dtype == 2) {
     out->set_dtype(DataType::INT32);
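Note: with the InferMeta change above, a 0-dim input leaves vec empty, so the output of argmax/argmin on a 0-dim tensor is itself 0-dim. A minimal sketch of the intended dygraph behavior (illustrative, assuming the default axis=None takes the flatten path):

import paddle

x = paddle.full([], 3.14)  # 0-dim tensor, x.shape == []
out = paddle.argmax(x)     # for 0-dim input, axis must be 0 or -1 with flatten
print(out.shape)           # [] -- the output is 0-dim as well
print(int(out))            # 0  -- the single element is trivially the max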
6 changes: 6 additions & 0 deletions paddle/phi/kernels/cpu/arg_min_max_kernel.cc
@@ -96,6 +96,12 @@ struct VisitDataArgMinMaxFunctor {
       if (axis < 0) new_axis = axis + x_dims.size();
     }
 
+    // For 0D Tensor
+    if (x.dims().size() == 0) {
+      phi::funcs::set_constant(dev_ctx, out, 0);
+      return;
+    }
+
 #define CALL_ARG_MINMAX_FUNCTOR(rank)                                         \
   ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
   functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
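The CPU kernel now short-circuits before the rank-dispatch macro: a 0-dim input has exactly one element, so the only valid index is 0 and the output is constant-filled instead of running the functor. This matches NumPy, which treats argmax/argmin over a 0-d array as a one-element flattened reduction:

import numpy as np

x = np.array(2.5)    # 0-d array, x.shape == ()
print(np.argmax(x))  # 0
print(np.argmin(x))  # 0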
7 changes: 7 additions & 0 deletions paddle/phi/kernels/gpu/arg_min_max_kernel.cu
@@ -30,6 +30,7 @@ namespace cub = hipcub;

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {

namespace { // NOLINT
@@ -180,6 +181,12 @@ struct VisitDataCudaArgMinMaxFunctor {
       x_dims = x.dims();
       if (axis < 0) new_axis = axis + x.dims().size();
     }
+    // For 0D Tensor
+    if (x.dims().size() == 0) {
+      dev_ctx.template Alloc<IndType>(out);
+      phi::funcs::set_constant(dev_ctx, out, 0);
+      return;
+    }
 
     int64_t numel = x.numel();
     int64_t groups = numel / x_dims[new_axis];
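Unlike the CPU path, the CUDA functor has to call Alloc explicitly, since the early return skips the allocation done later in the normal code path. A hypothetical device-parity check (assumes a CUDA build of Paddle):

import paddle

x = paddle.full([], 1.0)  # 0-dim input
cpu_out = paddle.argmax(x.cpu())
if paddle.is_compiled_with_cuda():
    gpu_out = paddle.argmax(x.cuda())
    assert int(cpu_out) == int(gpu_out) == 0  # both devices return index 0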
9 changes: 9 additions & 0 deletions paddle/phi/kernels/xpu/arg_min_max_kernel.cc
@@ -18,6 +18,7 @@
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

@@ -39,7 +40,15 @@ void ArgMaxKernel(const Context& dev_ctx,
                         DataType::INT64,
                         DataType::INT32,
                         dtype));
   // TODO(ZHUI): fix dtype of out
   dev_ctx.template Alloc<int64_t>(out);
+  if (x.dims().size() == 0) {
+    xpu::constant(dev_ctx.x_context(),
+                  out->data<int64_t>(),
+                  x.numel(),
+                  static_cast<int64_t>(0));
+    return;
+  }
 
   DDim x_dims;
   int axis_val = axis.to<int>();
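Per the TODO above, the XPU kernel still allocates an int64 output unconditionally, so a requested int32 result dtype is reflected by InferMeta (and honored on CPU/GPU) but not by this kernel yet. A sketch of the expectation this leaves open (hypothetical, inferred from the TODO and the dtype handling in unary.cc above):

import paddle

x = paddle.full([], 7.0)
out = paddle.argmax(x, dtype='int32')
print(out.dtype)  # paddle.int32 on CPU/GPU; the XPU kernel still writes int64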
15 changes: 10 additions & 5 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -189,6 +189,8 @@ def test_static_unary(self):
     paddle.logsumexp,
     paddle.all,
     paddle.any,
+    paddle.argmax,
+    paddle.argmin,
 ]


@@ -208,12 +210,13 @@ def test_dygraph_reduce(self):
         out.retain_grads()
         out.backward()
 
-        out_empty_list = api(x, [])
-        self.assertEqual(out_empty_list, out)
-
         self.assertEqual(x.shape, [])
         self.assertEqual(out.shape, [])
-        np.testing.assert_allclose(out.numpy(), x.numpy())
+        if api not in [paddle.argmax, paddle.argmin]:
+            np.testing.assert_allclose(out.numpy(), x.numpy())
+            out_empty_list = api(x, [])
+            self.assertEqual(out_empty_list, out)
 
         if x.grad is not None:
             self.assertEqual(x.grad.shape, [])
             self.assertEqual(out.grad.shape, [])
@@ -250,7 +253,9 @@ def test_static_reduce(self):
         res = exe.run(main_prog, fetch_list=fetch_list)
         self.assertEqual(res[0].shape, ())
         self.assertEqual(res[1].shape, ())
-        np.testing.assert_allclose(res[0], res[1])
+        if api not in [paddle.argmax, paddle.argmin]:
+            np.testing.assert_allclose(res[0], res[1])
+
         if len(res) > 2:
             self.assertEqual(res[2].shape, ())
             self.assertEqual(res[3].shape, ())
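argmax/argmin are excluded from the value assertion because they return an index into the input, not the reduced value: for a 0-dim x the result is always 0, so assert_allclose(out, x) would fail for any x != 0. The api(x, []) call is also skipped, since these ops take a scalar axis rather than a list. A short sketch of the distinction:

import paddle

x = paddle.full([], 5.0)      # 0-dim
print(float(paddle.max(x)))   # 5.0 -- reduce ops return the value (out == x)
print(int(paddle.argmax(x)))  # 0   -- arg ops return the index (out != x)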
@@ -132,6 +132,8 @@ def test_dygraph_unary(self):
     paddle.logsumexp,
     paddle.all,
     paddle.any,
+    paddle.argmax,
+    paddle.argmin,
 ]


@@ -153,7 +155,8 @@ def test_dygraph_reduce(self):

         self.assertEqual(x.shape, [])
         self.assertEqual(out.shape, [])
-        np.testing.assert_allclose(out.numpy(), x.numpy())
+        if api not in [paddle.argmax, paddle.argmin]:
+            np.testing.assert_allclose(out.numpy(), x.numpy())
         if x.grad is not None:
             self.assertEqual(x.grad.shape, [])
             self.assertEqual(out.grad.shape, [])