[Zero-Dim] support input 0D Tensor for min/max/amin/amax/prod/logsumexp/all/any #47501

Merged
merged 1 commit into from
Nov 3, 2022
86 changes: 8 additions & 78 deletions paddle/phi/infermeta/unary.cc
@@ -1805,84 +1805,14 @@ void LogsumexpInferMeta(const MetaTensor& input,
bool keepdim,
bool reduce_all,
MetaTensor* out) {
auto x_dims = input.dims();
auto x_rank = x_dims.size();
std::vector<int64_t> formated_axis = axis;
PADDLE_ENFORCE_LE(x_rank,
4,
errors::InvalidArgument(
"The input tensor X's dimensions of logsumexp "
"should be less or equal than 4. But received X's "
"dimensions = %d, X's shape = [%s].",
x_rank,
x_dims));
PADDLE_ENFORCE_GT(
axis.size(),
0,
errors::InvalidArgument(
"The size of axis of logsumexp "
"should be greater than 0. But received the size of axis "
"of logsumexp is %d.",
axis.size()));

for (size_t i = 0; i < axis.size(); i++) {
PADDLE_ENFORCE_LT(axis[i],
x_rank,
errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i,
x_rank,
i,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-x_rank,
errors::InvalidArgument(
"axis[%d] should be in the "
"range [-D, D), where D is the dimensions of X and "
"D is %d. But received axis[%d] = %d.",
i,
x_rank,
i,
axis[i]));
if (axis[i] < 0) {
formated_axis[i] += x_rank;
}
}

auto dims_vector = vectorize(x_dims);
if (reduce_all) {
if (keepdim)
out->set_dims(phi::make_ddim(std::vector<int64_t>(x_rank, 1)));
else
out->set_dims({1});
} else {
auto dims_vector = vectorize(x_dims);
if (keepdim) {
for (size_t i = 0; i < formated_axis.size(); ++i) {
dims_vector[formated_axis[i]] = 1;
}
} else {
const int kDelFlag = -1;
for (size_t i = 0; i < formated_axis.size(); ++i) {
dims_vector[formated_axis[i]] = kDelFlag;
}
dims_vector.erase(
std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
if (!keepdim && dims_vector.size() == 0) {
dims_vector.push_back(1);
}
auto out_dims = phi::make_ddim(dims_vector);
out->set_dims(out_dims);
if (formated_axis.size() > 0 && formated_axis[0] != 0) {
// Only pass LoD when not reducing on the first dim.
out->share_lod(input);
}
}
out->set_dtype(input.dtype());
auto input_rank = input.dims().size();
// only support 0~4D, due to slow Eigen template compilation
PADDLE_ENFORCE_LE(
input_rank,
4,
errors::InvalidArgument("The input tensor X's dimensions of logsumexp "
"should be less or equal than 4. "));
ReduceInferMetaBase(input, axis, keepdim, reduce_all, out);
}

void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
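With shape inference now delegated to ReduceInferMetaBase, a 0-D input with an empty axis yields a 0-D output. A minimal usage sketch of the intended user-facing behavior (illustrative only; it assumes a Paddle build that includes this change):

import paddle

x = paddle.rand([])        # 0-D tensor, shape []
out = paddle.logsumexp(x)  # axis=None reduces over the whole tensor
print(out.shape)           # expected: [] (0-D output after this change)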
5 changes: 3 additions & 2 deletions paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
@@ -60,8 +60,9 @@ void LogsumexpGradKernel(const Context& dev_ctx,
DenseTensor* in_grad) {
dev_ctx.template Alloc<T>(in_grad);

const auto input_dim_size = in.dims().size();
reduce_all |= (static_cast<int>(axis.size()) == input_dim_size);
if (axis.size() == 0 || static_cast<int>(axis.size()) == in.dims().size()) {
reduce_all = true;
}

if (reduce_all) {
auto x = phi::EigenVector<T>::Flatten(in);
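The condition above matches the forward kernel: an empty axis list (which covers 0-D inputs) or an axis list naming every dimension triggers a full reduction. The same rule in plain Python (a hypothetical helper for illustration, not part of the Paddle codebase):

def should_reduce_all(axis, ndim):
    # An empty axis means "reduce everything" (the 0-D case),
    # and listing every dimension is equivalent to a full reduction.
    return len(axis) == 0 or len(axis) == ndim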
8 changes: 4 additions & 4 deletions paddle/phi/kernels/impl/logsumexp_kernel_impl.h
@@ -69,9 +69,9 @@ void LogsumexpKernel(const Context& dev_ctx,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);

const auto& input_dim_size = x.dims().size();
// The dims has full dim, set the reduce_all is True
reduce_all |= (static_cast<int>(axis.size()) == input_dim_size);
if (axis.size() == 0 || static_cast<int>(axis.size()) == x.dims().size()) {
reduce_all = true;
}

if (reduce_all) {
// Flatten and reduce 1-D tensor
@@ -81,7 +81,7 @@ void LogsumexpKernel(const Context& dev_ctx,
auto reduce_dim = Eigen::array<int, 1>({{0}});
LogsumexpFunctor<T>()(place, &input, &output, reduce_dim);
} else {
int ndim = input_dim_size;
int ndim = x.dims().size();
int rdim = axis.size();
if (ndim > 4) {
PADDLE_THROW(phi::errors::Unimplemented(
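In the reduce_all branch the input is flattened to 1-D and reduced along axis 0, i.e. a single log-sum-exp over all elements. A rough NumPy sketch of that result (illustrative; the diff does not show LogsumexpFunctor's internals, so the max-shift below is only the standard stabilization trick, not necessarily what the functor does):

import numpy as np

def logsumexp_all(x):
    flat = np.asarray(x).reshape(-1)  # flatten, even for a 0-D input
    m = flat.max()                    # shift by the max for numerical stability
    return m + np.log(np.exp(flat - m).sum())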
3 changes: 3 additions & 0 deletions paddle/phi/kernels/reduce_any_kernel.cc
@@ -26,6 +26,9 @@ void AnyKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
reduce_all = true;
}
AnyRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}

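With dims left empty, a 0-D boolean input now takes the reduce_all path. A small usage sketch (illustrative):

import paddle

x = paddle.randint(0, 2, []).astype('bool')  # 0-D bool tensor, as in the tests below
out = paddle.any(x)                          # empty dims -> reduce over everything
print(out.shape)                             # expected: []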
6 changes: 6 additions & 0 deletions python/paddle/fluid/tests/unittests/test_logsumexp.py
@@ -103,6 +103,12 @@ def calc_grad(self):
return dy * np.exp(x - y)


class TestLogsumexp_ZeroDim(TestLogsumexp):
def set_attrs(self):
self.shape = []
self.axis = []


class TestLogsumexp_shape(TestLogsumexp):
def set_attrs(self):
self.shape = [4, 5, 6]
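For a 0-D input the reference value in TestLogsumexp_ZeroDim is trivial: logsumexp(x) = log(exp(x)) = x, so the expected output equals the input itself. A quick NumPy check (illustrative):

import numpy as np

x = np.random.random([])         # 0-D ndarray, shape ()
ref = np.log(np.sum(np.exp(x)))  # logsumexp over all (zero) axes
assert np.allclose(ref, x)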
@@ -136,6 +136,15 @@ def _test_dygraph(func):
# test two minimum or maximum elements


class TestMaxMinAmaxAminAPI_ZeroDim(TestMaxMinAmaxAminAPI):
def init_case(self):
self.x_np = np.array(0.5)
self.shape = []
self.dtype = 'float64'
self.axis = None
self.keepdim = False


class TestMaxMinAmaxAminAPI2(TestMaxMinAmaxAminAPI):
def init_case(self):
self.x_np = np.array([[0.2, 0.3, 0.9, 0.9], [0.1, 0.1, 0.6, 0.7]])
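For these APIs the 0-D case (axis=None, keepdim=False) should likewise return a 0-D result. A minimal sketch (illustrative):

import paddle

x = paddle.rand([])  # 0-D input
for api in (paddle.max, paddle.min, paddle.amax, paddle.amin):
    out = api(x)      # axis=None reduces over the whole tensor
    print(out.shape)  # expected: [] for each API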
71 changes: 71 additions & 0 deletions python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -217,6 +217,22 @@ def test_check_output(self):
self.check_output(check_eager=True)


class TestMaxOp_ZeroDim(OpTest):
"""Remove Max with subgradient from gradient check to confirm the success of CI."""

def setUp(self):
self.op_type = "reduce_max"
self.python_api = paddle.max
self.inputs = {'X': np.random.random([]).astype("float64")}
self.attrs = {'dim': []}
self.outputs = {
'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
}

def test_check_output(self):
self.check_output(check_eager=True)


@skip_check_grad_ci(
reason="reduce_min is discontinuous non-derivable function,"
" its gradient check is not supported by unittest framework."
@@ -237,6 +253,22 @@ def test_check_output(self):
self.check_output(check_eager=True)


class TestMinOp_ZeroDim(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""

def setUp(self):
self.op_type = "reduce_min"
self.python_api = paddle.min
self.inputs = {'X': np.random.random([]).astype("float64")}
self.attrs = {'dim': []}
self.outputs = {
'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
}

def test_check_output(self):
self.check_output(check_eager=True)


class TestMin6DOp(OpTest):
"""Remove Min with subgradient from gradient check to confirm the success of CI."""

@@ -297,6 +329,21 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)


class TestProdOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.prod
self.op_type = "reduce_prod"
self.inputs = {'X': np.random.random([]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}

def test_check_output(self):
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)


class TestProd6DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
@@ -361,6 +408,18 @@ def test_check_output(self):
self.check_output(check_eager=True)


class TestAllOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.all
self.op_type = "reduce_all"
self.inputs = {'X': np.random.randint(0, 2, []).astype("bool")}
self.outputs = {'Out': self.inputs['X'].all()}
self.attrs = {'dim': [], 'reduce_all': True}

def test_check_output(self):
self.check_output(check_eager=True)


class TestAll8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_all"
@@ -464,6 +523,18 @@ def test_check_output(self):
self.check_output(check_eager=True)


class TestAnyOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.any
self.op_type = "reduce_any"
self.inputs = {'X': np.random.randint(0, 2, []).astype("bool")}
self.outputs = {'Out': self.inputs['X'].any()}
self.attrs = {'dim': [], 'reduce_all': True}

def test_check_output(self):
self.check_output(check_eager=True)


class TestAny8DOp(OpTest):
def setUp(self):
self.op_type = "reduce_any"
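These OpTest cases build their reference outputs directly with NumPy: np.random.random([]) and np.random.randint(0, 2, []) yield 0-D ndarrays, and reducing them with an empty axis tuple or a full reduction keeps the 0-D shape, which is what the kernels are expected to return. A short check of that assumption (illustrative):

import numpy as np

x = np.random.random([])     # 0-D ndarray
print(x.shape)               # ()
print(x.max(axis=()).shape)  # () -- an empty axis tuple reduces over no axes
print(x.prod().shape)        # () -- a full reduction of a 0-D array stays 0-D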
36 changes: 26 additions & 10 deletions python/paddle/fluid/tests/unittests/test_zero_dim_shape.py
@@ -165,6 +165,14 @@ def test_static_unary(self):
paddle.mean,
paddle.nansum,
paddle.nanmean,
paddle.min,
paddle.max,
paddle.amin,
paddle.amax,
paddle.prod,
paddle.logsumexp,
paddle.all,
paddle.any,
]


@@ -173,15 +181,21 @@ def test_dygraph(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for api in reduce_api_list:
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
out.backward()
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, []).astype('bool')
out = api(x, None)
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
else:
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
out.backward()

self.assertEqual(x.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])

paddle.enable_static()

@@ -190,11 +204,13 @@ def test_static(self):
for api in reduce_api_list:
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
x = paddle.rand([])
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, []).astype('bool')
else:
x = paddle.rand([])

x.stop_gradient = False
out = api(x, None)
fluid.backward.append_backward(out)

# Test compile shape, grad is always [1]
self.assertEqual(x.shape, ())
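Gradients of these reductions also stay 0-D in dygraph mode, as the assertions in test_dygraph above require. A compact end-to-end sketch (illustrative):

import paddle

paddle.disable_static()
x = paddle.rand([])
x.stop_gradient = False
out = paddle.prod(x)
out.backward()
print(out.shape, x.grad.shape)  # expected: [] []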