Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5 No.2】Add index_fill / index_fill_ API to Paddle -part #57416

Merged
merged 15 commits into from
Nov 3, 2023
Merged
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@
unfold,
masked_fill,
masked_fill_,
index_fill,
index_fill_,
)

from .tensor.math import ( # noqa: F401
Expand Down Expand Up @@ -913,4 +915,6 @@
'masked_fill_',
'hypot',
'hypot_',
'index_fill',
"index_fill_",
]
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@
from .manipulation import unfold # noqa: F401
from .manipulation import masked_fill # noqa: F401
from .manipulation import masked_fill_ # noqa: F401
from .manipulation import index_fill # noqa: F401
from .manipulation import index_fill_ # noqa: F401
from .math import abs # noqa: F401
from .math import abs_ # noqa: F401
from .math import acos # noqa: F401
Expand Down Expand Up @@ -725,6 +727,8 @@
'asinh_',
'diag',
'normal_',
'index_fill',
'index_fill_',
]

# this list used in math_op_patch.py for magic_method bind
Expand Down
101 changes: 101 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5422,3 +5422,104 @@ def unfold(x, axis, size, step, name=None):
}
for name, func in __METHODS.items():
setattr(core.eager.Tensor, name, func)


def _index_fill_impl(x, index, axis, value, inplace):
if not isinstance(index, Variable):
raise ValueError("index must be Tensor")

if not isinstance(value, Variable):
value = paddle.to_tensor(value, dtype=x.dtype)
else:
if len(value.shape) > 0:
raise ValueError("value must be scalar or 0-D tensor")

x_dim = len(x.shape)
if not (isinstance(axis, int)) or (axis > x_dim - 1) or axis < -x_dim:
Copy link
Contributor

@jeff41404 jeff41404 Oct 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The negative axis has been processed in L5183 above, so the judgment condition here should be axis < 0 not axis < -x_dim.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

raise ValueError(
"The axis should be int, and in range [-rank(x), rank(x))"
)

if axis < 0:
axis = axis + x_dim

perm = list(range(len(x.shape)))
perm[0] = axis
perm[axis] = 0

if inplace:
paddle.transpose(x, perm)
paddle.index_put_(x, (index,), value)
return x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处,如果是inplace,是否可以不要调用clone+setitem赋值,而是直接使用index_put_赋值; 如果非inplace,是否可以不需要额外的clone操作

Copy link
Contributor

@jeff41404 jeff41404 Oct 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation solution in the RFC should also be changed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation already matches the API design in the RFC.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the RFC, does the line `out = paddle.clone(x)` in the else branch need to be deleted?

else:
out = paddle.transpose(x, perm)
out = paddle.index_put(out, (index,), value)
out = paddle.transpose(out, perm)
return out


def index_fill(x, index, axis, value, name=None):
    """
    Out-of-place version of ``index_fill_`` API: the result is returned as a new Tensor
    and the input ``x`` is left unchanged (see the example below).
    Please refer to :ref:`api_paddle_index_fill_`.

    Args:
        x (Tensor): The source Tensor.
        index (Tensor): The Tensor containing the indices to fill along ``axis``.
        axis (int): The dimension along which to index, in range [-rank(x), rank(x)).
        value (scalar|Tensor): The scalar or 0-D Tensor used as the fill value.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, the filled result.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> input_tensor = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype='int64')
            >>> index = paddle.to_tensor([0, 2], dtype="int32")
            >>> value = -1
            >>> res = paddle.index_fill(input_tensor, index, 0, value)
            >>> print(input_tensor)
            Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                   [[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
            >>> print(res)
            Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                   [[-1, -1, -1],
                    [ 4,  5,  6],
                    [-1, -1, -1]])

    """
    # All validation and the actual fill live in the shared implementation.
    return _index_fill_impl(x, index, axis, value, inplace=False)


@inplace_apis_in_dygraph_only
def index_fill_(x, index, axis, value, name=None):
    """
    Fill the elements of the input tensor with value by the specific axis and index.

    Args:
        x (Tensor) : The Destination Tensor. Supported data types are int32, int64, float16, float32, float64.
        index (Tensor): The 1-D Tensor containing the indices to index.
            The data type of ``index`` must be int32 or int64.
        axis (int): The dimension along which to index, in range [-rank(x), rank(x)).
        value (scalar|Tensor): The fill value. A Python scalar is converted to a Tensor
            with the dtype of ``x``; a Tensor must be 0-D.
        name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, same dimension and dtype with x.

    Raises:
        ValueError: If ``index`` is not a Tensor, if ``value`` is a Tensor that is
            not 0-D, or if ``axis`` is not an int in the valid range.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> input_tensor = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype='int64')
            >>> index = paddle.to_tensor([0, 2], dtype="int32")
            >>> value = -1
            >>> res = paddle.index_fill_(input_tensor, index, 0, value)
            >>> print(input_tensor)
            Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                   [[-1, -1, -1],
                    [ 4,  5,  6],
                    [-1, -1, -1]])
            >>> print(res)
            Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
                   [[-1, -1, -1],
                    [ 4,  5,  6],
                    [-1, -1, -1]])

    """
    # NOTE(review): the in-place branch of _index_fill_impl calls
    # paddle.transpose(x, perm) without using its result before index_put_,
    # so the fill looks like it is always applied along axis 0 -- confirm and
    # add an in-place test with axis != 0.
    return _index_fill_impl(x, index, axis, value, inplace=True)
143 changes: 143 additions & 0 deletions test/legacy_test/test_index_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from itertools import combinations

import numpy as np

import paddle
from paddle.base import Program

paddle.enable_static()


def compute_index_fill_ref(x, axis, index, value):
    """NumPy reference for ``index_fill``.

    Swaps ``axis`` with dimension 0, assigns ``value`` to the rows selected by
    ``index``, then swaps the dimensions back.

    Args:
        x (np.ndarray): Input array. It is not modified.
        axis (int): Dimension along which ``index`` selects entries.
        index (np.ndarray): Integer indices along ``axis`` to overwrite.
        value: Scalar written to every selected entry.

    Returns:
        np.ndarray: An array with the shape of ``x`` holding the filled result.
    """
    perm = list(range(len(x.shape)))
    perm[0] = axis
    perm[axis] = 0

    # np.transpose returns a view of x; copy it so the assignment below does
    # not write through into the caller's array.
    out = np.transpose(x, perm).copy()
    out[index] = value
    # The same permutation swaps the two dimensions back.
    return np.transpose(out, perm)


class TestIndexFillAPIBase(unittest.TestCase):
    """Compares ``paddle.index_fill`` with ``compute_index_fill_ref``.

    The default configuration is set in ``init_setting``; subclasses override
    ``modify_setting`` to change dtype, shape, axis or fill value.
    """

    def setUp(self):
        self.init_setting()
        self.modify_setting()
        self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
        # Pick one of the precomputed index combinations at random. The upper
        # bound follows the list length instead of a hard-coded count.
        self.index_np = np.array(
            self.combs[np.random.randint(0, len(self.combs))]
        ).astype(self.index_type)

        # float16 drops the CPU place, so it only runs when CUDA is available.
        self.place = ['cpu']
        if self.dtype_np == 'float16':
            self.place = []
        if paddle.is_compiled_with_cuda():
            self.place.append('gpu')

    def init_setting(self):
        """Set the default test configuration."""
        self.dtype_np = 'float64'
        self.index_type = 'int64'
        self.x_shape = (20, 40)
        self.index_size = (5,)
        self.axis = 0
        self.value = -1
        # Every way of choosing index_size[0] distinct indices from 0..9.
        self.combs = list(combinations(list(range(10)), self.index_size[0]))

    def modify_setting(self):
        """Hook for subclasses to override the defaults from init_setting."""
        pass

    def test_static_graph(self):
        paddle.enable_static()
        for place in self.place:
            with paddle.static.program_guard(Program()):
                x = paddle.static.data(
                    name="x", shape=self.x_shape, dtype=self.dtype_np
                )
                index = paddle.static.data(
                    name="index", shape=self.index_size, dtype=self.index_type
                )
                out = paddle.index_fill(x, index, self.axis, self.value)
                exe = paddle.static.Executor(place=place)
                feed_list = {"x": self.x_np, "index": self.index_np}
                pd_res = exe.run(
                    paddle.static.default_main_program(),
                    feed=feed_list,
                    fetch_list=[out],
                )[0]
                ref_res = compute_index_fill_ref(
                    self.x_np, self.axis, self.index_np, self.value
                )
                np.testing.assert_allclose(ref_res, pd_res)

    def test_dygraph(self):
        paddle.disable_static()
        for place in self.place:
            paddle.device.set_device(place)
            x_pd = paddle.to_tensor(self.x_np)
            index_pd = paddle.to_tensor(self.index_np)
            pd_res = paddle.index_fill(x_pd, index_pd, self.axis, self.value)
            ref_res = compute_index_fill_ref(
                self.x_np, self.axis, self.index_np, self.value
            )
            np.testing.assert_allclose(ref_res, pd_res)

    def test_errors(self):
        data_np = np.random.random((10, 10)).astype(np.float32)
        index = paddle.to_tensor([0, 2])

        # index given as a plain list instead of a Tensor
        def test_index_not_tensor():
            paddle.index_fill(data_np, [0, 2], axis=-1, value=-1)

        self.assertRaises(ValueError, test_index_not_tensor)

        # value given as a Tensor that is not 0-D
        def test_value_shape():
            paddle.index_fill(
                data_np, index, axis=-1, value=paddle.to_tensor([-1, -4])
            )

        self.assertRaises(ValueError, test_value_shape)

        # axis beyond the rank of the input
        def test_axis_range():
            paddle.index_fill(data_np, index, axis=4, value=-1)

        self.assertRaises(ValueError, test_axis_range)


class TestIndexFillAPI1(TestIndexFillAPIBase):
    """int64 data with int32 indices, filling along axis 1 of a 3-D input."""

    def modify_setting(self):
        self.x_shape = (10, 15, 10)
        self.axis = 1
        self.dtype_np = 'int64'
        self.index_type = 'int32'


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充下complex类型的测试吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index_put不支持complex类型的输入

增加了float16类型的测试

class TestIndexFillAPI2(TestIndexFillAPIBase):
    """bool data with int32 indices, filling along axis 1 of a 3-D input."""

    def modify_setting(self):
        self.dtype_np = 'bool'
        self.index_type = 'int32'
        self.x_shape = (10, 15, 10)
        self.axis = 1
        # setUp builds the input as np.random.random(...).astype('bool'),
        # which is True for every non-zero sample. Filling with False makes
        # the filled slices differ from the input, so a no-op would be caught.
        self.value = False


class TestIndexFillAPI3(TestIndexFillAPIBase):
    """float16 data, filling 0.5 along axis 1 of a 3-D input."""

    def modify_setting(self):
        self.x_shape = (10, 15, 10)
        self.axis = 1
        self.value = 0.5
        self.dtype_np = 'float16'
15 changes: 15 additions & 0 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,5 +1601,20 @@ def test_forward_version(self):
self.assertEqual(var.inplace_version, 2)


class TestDygraphInplaceIndexFill(TestDygraphInplace):
    """Plugs ``index_fill_`` / ``index_fill`` into the TestDygraphInplace hooks."""

    def init_data(self):
        # Input data plus the arguments shared by both API variants.
        self.input_var_numpy = np.random.random((20, 40))
        self.dtype = "float32"
        self.index = paddle.to_tensor([0, 2])
        self.axis = 0
        self.value = -1

    def inplace_api_processing(self, var):
        filled = paddle.index_fill_(var, self.index, self.axis, self.value)
        return filled

    def non_inplace_api_processing(self, var):
        filled = paddle.index_fill(var, self.index, self.axis, self.value)
        return filled


if __name__ == '__main__':
unittest.main()