From 501e520e503b689dcb7c1cefe410437e44a4d7f9 Mon Sep 17 00:00:00 2001 From: JYChen Date: Mon, 25 Dec 2023 15:11:10 +0800 Subject: [PATCH 1/5] Fix set value grad (#59034) * first fix the UT * fix set value grad * polish code * add static mode backward test * always has input valuetensor * add dygraph test --- paddle/fluid/operators/set_value_op.cc | 44 +++++----- paddle/phi/api/yaml/legacy_backward.yaml | 6 +- .../phi/kernels/cpu/set_value_grad_kernel.cc | 17 ++++ .../phi/kernels/gpu/set_value_grad_kernel.cu | 17 ++++ .../kernels/impl/set_value_grad_kernel_impl.h | 22 +++++ paddle/phi/kernels/set_value_grad_kernel.h | 10 +++ .../phi/kernels/xpu/set_value_grad_kernel.cc | 31 +++++++ test/legacy_test/test_set_value_op.py | 82 +++++++++++++++++++ 8 files changed, 201 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index 16864b80b5c76..a0aa1f589191f 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -151,32 +151,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker { protected: void Apply(GradOpPtr op) const override { - if (this->HasInput("ValueTensor")) { - op->SetType("set_value_grad"); - - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetInput("ValueTensor", this->Input("ValueTensor")); - if (this->HasInput("StartsTensorList")) { - op->SetInput("StartsTensorList", this->Input("StartsTensorList")); - } - if (this->HasInput("EndsTensorList")) { - op->SetInput("EndsTensorList", this->Input("EndsTensorList")); - } - if (this->HasInput("StepsTensorList")) { - op->SetInput("StepsTensorList", this->Input("StepsTensorList")); - } - - op->SetAttrMap(this->Attrs()); - - op->SetOutput(framework::GradVarName("ValueTensor"), - this->InputGrad("ValueTensor")); - op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); - - } else { - op->SetType("assign"); - op->SetInput("X", this->OutputGrad("Out")); - op->SetOutput("Out", this->InputGrad("Input")); + op->SetType("set_value_grad"); + op->SetInput("ValueTensor", this->Input("ValueTensor")); + op->SetOutput(framework::GradVarName("ValueTensor"), + this->InputGrad("ValueTensor")); + + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + + if (this->HasInput("StartsTensorList")) { + op->SetInput("StartsTensorList", this->Input("StartsTensorList")); + } + if (this->HasInput("EndsTensorList")) { + op->SetInput("EndsTensorList", this->Input("EndsTensorList")); } + if (this->HasInput("StepsTensorList")) { + op->SetInput("StepsTensorList", this->Input("StepsTensorList")); + } + + op->SetAttrMap(this->Attrs()); + + op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); } }; diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 04cf57a88bb7c..3f11781dfe88e 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -614,14 +614,14 @@ - backward_op : set_value_grad forward : set_value (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) -> Tensor(out) - args : (Tensor out_grad) + args : (Tensor out_grad, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) output : Tensor(x_grad) infer_meta: func: UnchangedInferMeta param: [out_grad] kernel: - func: assign - param: [out_grad] + func: set_value_with_scalar_grad + param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes] - backward_op : set_value_with_tensor_grad forward: set_value_with_tensor (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) -> Tensor(out) diff --git a/paddle/phi/kernels/cpu/set_value_grad_kernel.cc b/paddle/phi/kernels/cpu/set_value_grad_kernel.cc index ed35513d98550..237a892dbb356 100644 --- a/paddle/phi/kernels/cpu/set_value_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/set_value_grad_kernel.cc @@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad, phi::dtype::float16, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(set_value_with_scalar_grad, + CPU, + ALL_LAYOUT, + phi::SetValueWithScalarGradKernel, + float, + double, + int, + int64_t, + bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::bfloat16, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/set_value_grad_kernel.cu b/paddle/phi/kernels/gpu/set_value_grad_kernel.cu index 66688b417ae30..42ff5b912eccd 100644 --- a/paddle/phi/kernels/gpu/set_value_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/set_value_grad_kernel.cu @@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(set_value_with_scalar_grad, + GPU, + ALL_LAYOUT, + phi::SetValueWithScalarGradKernel, + float, + double, + int, + int64_t, + bool, + int16_t, + uint8_t, + int8_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h b/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h index 3f78361b92b8b..99f05f80c17ff 100644 --- a/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h @@ -341,4 +341,26 @@ void SetValueGradKernel(const Context& dev_ctx, } } +template +void SetValueWithScalarGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const IntArray& starts, + const IntArray& ends, + const IntArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* x_grad) { + SetValueGradKernel(dev_ctx, + out_grad, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + x_grad, + nullptr); +} + } // namespace phi diff --git a/paddle/phi/kernels/set_value_grad_kernel.h b/paddle/phi/kernels/set_value_grad_kernel.h index e4dad683e40a9..04592cd2002d1 100644 --- a/paddle/phi/kernels/set_value_grad_kernel.h +++ b/paddle/phi/kernels/set_value_grad_kernel.h @@ -32,4 +32,14 @@ void SetValueGradKernel(const Context& dev_ctx, DenseTensor* x_grad, DenseTensor* value_grad); +template +void SetValueWithScalarGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const IntArray& starts, + const IntArray& ends, + const IntArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc index d1ad332cd626c..c5d33ae4ac8d0 100644 --- a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc @@ -397,6 +397,28 @@ void SetValueGradKernel(const Context& dev_ctx, } } +template +void SetValueWithScalarGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const IntArray& starts, + const IntArray& ends, + const IntArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* x_grad) { + SetValueGradKernel(dev_ctx, + out_grad, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + x_grad, + nullptr); +} + } // namespace phi PD_REGISTER_KERNEL(set_value_grad, @@ -407,3 +429,12 @@ PD_REGISTER_KERNEL(set_value_grad, phi::dtype::float16, int, int64_t) {} + +PD_REGISTER_KERNEL(set_value_with_scalar_grad, + XPU, + ALL_LAYOUT, + phi::SetValueWithScalarGradKernel, + float, + phi::dtype::float16, + int, + int64_t) {} diff --git a/test/legacy_test/test_set_value_op.py b/test/legacy_test/test_set_value_op.py index 65c9f69765d11..c42026fb9caee 100644 --- a/test/legacy_test/test_set_value_op.py +++ b/test/legacy_test/test_set_value_op.py @@ -1978,5 +1978,87 @@ def test_check_grad(self): self.check_grad_with_place(place, ['Input'], 'Out', check_dygraph=False) +class TestSetValueWithScalarInStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.shape = (10, 2) + self.exe = paddle.static.Executor() + self.train_program = paddle.static.Program() + self.startup_program = paddle.static.Program() + + def test_value_input_is_scalar(self): + with paddle.static.program_guard( + self.train_program, self.startup_program + ): + x = paddle.ones(self.shape) + x.stop_gradient = False + y = x * 1 + + # mock test case x[0, 0] = 10 with no ValueTensor input + inputs = { + 'Input': y, + } + attrs = { + 'axes': [0, 1], + 'starts': [0, 0], + 'ends': [1, 1], + 'steps': [1, 1], + 'values': [10], + 'shape': [1], + } + + helper = LayerHelper("set_value") + out = helper.create_variable_for_type_inference(dtype=y.dtype) + + helper.append_op( + type="set_value", + inputs=inputs, + outputs={'Out': out}, + attrs=attrs, + ) + + np_data = np.ones(self.shape).astype('float32') + + paddle.static.append_backward(out.sum()) + res = self.exe.run( + self.train_program, fetch_list=[out, x.grad_name] + ) + + np_data[0, 0] = 10 + expected_x_grad = np.ones(self.shape) + expected_x_grad[0, 0] = 0 + + np.testing.assert_array_equal(res[0], np_data) + np.testing.assert_array_equal(res[1], expected_x_grad) + + +class TestSetValueWithScalarInDygraph(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.shape = (10, 2) + + def test_value_input_is_scalar(self): + x = paddle.ones(self.shape) + x.stop_gradient = False + y = x * 1 + + # mock test case x[0, 0] = 10 with no ValueTensor input + out = paddle._C_ops.set_value( + y, [0, 0], [1, 1], [1, 1], [0, 1], [], [], [1], [10.0] + ) + + loss = out.sum() + loss.backward() + + np_data = np.ones(self.shape).astype('float32') + np_data[0, 0] = 10 + + expected_x_grad = np.ones(self.shape) + expected_x_grad[0, 0] = 0 + + np.testing.assert_array_equal(out, np_data) + np.testing.assert_array_equal(x.grad, expected_x_grad) + + if __name__ == '__main__': unittest.main() From 94e5da3256cc8c93d688bd8441d206ceb9a2095c Mon Sep 17 00:00:00 2001 From: JYChen Date: Tue, 2 Jan 2024 12:07:21 +0800 Subject: [PATCH 2/5] Fix shape error in combined-indexing setitem (#60447) * add ut * fix shape error in combine-indexing * fix ut --- paddle/fluid/pybind/eager_method.cc | 16 ++- paddle/fluid/pybind/slice_utils.h | 43 ++++----- python/paddle/base/variable_index.py | 45 ++++++--- test/indexing/test_setitem.py | 139 +++++++++++++++++++++++++++ 4 files changed, 200 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 5effab997848d..37e1e80774d22 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1375,7 +1375,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self, // step3: Dealing with advanced indexing std::vector transed_index; - std::vector trans_back_dim; + std::vector trans_back_dim, trans_dim; int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1; paddle::Tensor transed_tensor = dealWithAdvancedIndex(out, @@ -1385,7 +1385,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self, &transed_index, &trans_back_dim, &pos_of_new_dim, - &rank_of_new_dim); + &rank_of_new_dim, + &trans_dim); if (transed_index.size() == 1 && transed_index[0].dtype() == phi::DataType::BOOL) { @@ -1679,9 +1680,9 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, &use_strided_slice); std::vector transed_index; - std::vector trans_back_dim; + std::vector trans_back_dim, trans_dim; - int pos_of_new_dim = 0, rank_of_new_dim = 0; + int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1; paddle::Tensor transed_sub_tensor = dealWithAdvancedIndex(sub_tensor, @@ -1691,7 +1692,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, &transed_index, &trans_back_dim, &pos_of_new_dim, - &rank_of_new_dim); + &rank_of_new_dim, + &trans_dim); // Release gil and do tracing py::gil_scoped_release release; @@ -1714,6 +1716,10 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, } } + if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) { + value_tensor = transpose_ad_func(value_tensor, trans_dim); + } + // TODO(zoooo0820) 1.Using inplace version index_put // 2.Remove following code after backward bug fixed. transed_sub_tensor = assign_ad_func(transed_sub_tensor); diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index 918d2eeae4272..73ad179d6782a 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -397,9 +397,8 @@ static paddle::Tensor dealWithAdvancedIndex( std::vector* transed_index, std::vector* trans_back_dim, int* pos_of_new_dim, - int* rank_of_new_dim) { - std::vector trans_dim; - + int* rank_of_new_dim, + std::vector* trans_dim) { int p = 0; for (size_t i = 0; i < advanced_index_dim->size(); ++i) { auto index_dim = (*advanced_index_dim)[i]; @@ -408,30 +407,28 @@ static paddle::Tensor dealWithAdvancedIndex( // advanced_index_dim auto index = (*advanced_index)[p++]; - if (!is_for_setitem) { - if (index_dim == 0) { - // case 1: advanced indices at axis 0, the new dim will be at first. - *pos_of_new_dim = 0; - } else if (index_dim > 0 && trans_dim.size() > 0 && - trans_dim[trans_dim.size() - 1] != index_dim - 1) { - // case 2: there are not adjacent advanced indices, the new dim will - // be at first. - *pos_of_new_dim = 0; - } else { - *pos_of_new_dim = std::min(index_dim, *pos_of_new_dim); - } - *rank_of_new_dim = - std::max(*rank_of_new_dim, static_cast(index.shape().size())); + if (index_dim == 0) { + // case 1: advanced indices at axis 0, the new dim will be at first. + *pos_of_new_dim = 0; + } else if (index_dim > 0 && trans_dim->size() > 0 && + (*trans_dim)[trans_dim->size() - 1] != index_dim - 1) { + // case 2: there are not adjacent advanced indices, the new dim will + // be at first. + *pos_of_new_dim = 0; + } else { + *pos_of_new_dim = std::min(index_dim, *pos_of_new_dim); } + *rank_of_new_dim = + std::max(*rank_of_new_dim, static_cast(index.shape().size())); - trans_dim.push_back(index_dim); + trans_dim->push_back(index_dim); transed_index->push_back(std::move(index)); } } for (size_t i = 0; i < tensor.shape().size(); ++i) { if ((*advanced_index_dim)[i] == -1) { - trans_dim.push_back(i); + trans_dim->push_back(i); } } @@ -441,19 +438,19 @@ static paddle::Tensor dealWithAdvancedIndex( std::vector original_dim_order(tensor.shape().size()); std::iota(original_dim_order.begin(), original_dim_order.end(), 0); - if (original_dim_order == trans_dim) { + if (original_dim_order == *trans_dim) { transed_tensor = tensor; } else { - transed_tensor = transpose_ad_func(tensor, trans_dim); + transed_tensor = transpose_ad_func(tensor, *trans_dim); } if (is_for_setitem) { - trans_back_dim->resize(trans_dim.size()); + trans_back_dim->resize(trans_dim->size()); std::iota(trans_back_dim->begin(), trans_back_dim->end(), 0); std::sort(trans_back_dim->begin(), trans_back_dim->end(), [&trans_dim](int left, int right) { - return trans_dim[left] < trans_dim[right]; + return (*trans_dim)[left] < (*trans_dim)[right]; }); } return transed_tensor; diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index f3a04076ef3fb..5cb0142db92a1 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -191,7 +191,7 @@ def _setitem_for_tensor_array(var, item, value): ) -def deal_advanced_index(ori_tensor, indices, is_for_setitem): +def deal_advanced_index(ori_tensor, indices, is_for_setitem, values): """ Transpose origin Tensor and advanced indices to the front. @@ -201,6 +201,7 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem): trans_back_dim (List): order of axes to transpose back to original order. Only used in __setitem__. pos_of_new_dim (int): axis of new dim in the result. Only used in __getitem__. rank_of_new_dim (int): rank of new dim in the result. Only used in __getitem__. + transed_value_tensor (Tensor): value tensor transed to the front. Only used in __setitem__. """ transed_dim = [] transed_index = [] @@ -212,16 +213,15 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem): for i, indice in enumerate(indices): if indice is not None: - if not is_for_setitem: - if i == 0: - # case 1: advanced indices at axis 0, the new dim will be at first. - pos_of_new_dim = 0 - if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1: - # case 2: there are not adjacent advanced indices, the new dim will be at first. - pos_of_new_dim = 0 - else: - pos_of_new_dim = min(pos_of_new_dim, i) - rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim) + if i == 0: + # case 1: advanced indices at axis 0, the new dim will be at first. + pos_of_new_dim = 0 + if i > 0 and len(transed_dim) > 0 and transed_dim[-1] != i - 1: + # case 2: there are not adjacent advanced indices, the new dim will be at first. + pos_of_new_dim = 0 + else: + pos_of_new_dim = min(pos_of_new_dim, i) + rank_of_new_dim = max(rank_of_new_dim, indice[1].ndim) transed_dim.append(i) transed_index.append(indice[1]) for i in range(ori_tensor.ndim): @@ -231,12 +231,22 @@ def deal_advanced_index(ori_tensor, indices, is_for_setitem): trans_back_dim = np.argsort(transed_dim).tolist() if is_for_setitem else [] + transed_value_tensor = None + if is_for_setitem: + if values.ndim > 1 and pos_of_new_dim != 0: + # If the value tensor is not a scalar / 1-D Tensor, and the src tensor was + # transposed at 1st dim, the value tensor should be transposed too. + transed_value_tensor = values.transpose(transed_dim) + else: + transed_value_tensor = values + return ( transed_tensor, transed_index, trans_back_dim, pos_of_new_dim, rank_of_new_dim, + transed_value_tensor, ) @@ -550,6 +560,11 @@ def _setitem_static(x, indices, values): # 3. assign values to the sliced result by index_put OP; # 4. transpose back and assign the result to original tensor by set_value OP. + if not isinstance( + values, (Variable, paddle.pir.Value, paddle.pir.OpResult) + ): + values = paddle.assign(values).astype(x.dtype) + sub_tensor = get_tensor_with_basic_indexing( x, axes, @@ -566,9 +581,8 @@ def _setitem_static(x, indices, values): transback_dim, _, _, - ) = deal_advanced_index(sub_tensor, advanced_index, True) - if not isinstance(values, (Variable, paddle.pir.Value)): - values = paddle.assign(values).astype(transed_sub_tensor.dtype) + values, + ) = deal_advanced_index(sub_tensor, advanced_index, True, values) if values.dtype != transed_sub_tensor.dtype: values = values.astype(transed_sub_tensor.dtype) @@ -769,7 +783,8 @@ def _getitem_static(x, indices): _, pos_of_new_dim, rank_of_new_dim, - ) = deal_advanced_index(out, advanced_index, False) + _, + ) = deal_advanced_index(out, advanced_index, False, None) # TODO(zooooo0820): Replacing gather_nd to another advanded OP for handling of mixed indexes more efficiently if ( diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index b8d7e3361efc4..8e1b0bbe72a04 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -228,6 +228,79 @@ def test_indexing_is_boolean_false(self): np.testing.assert_allclose(x.numpy(), np_data) + def test_combined_indexing_and_value_is_tensor_1(self): + # value is tensor with same shape to getitem and index will be adjusted + np_data = np.ones((3, 3)).astype(self.ndtype) + value_data = np.array([-1, -1, -1]).astype(self.ndtype) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + value_data = convert_uint16_to_float( + convert_float_to_uint16(value_data) + ) + if self.dtype == 'complex64' or self.dtype == 'complex128': + np_data = np_data + 1j * np_data + value_data = value_data + 1j * value_data + + x = paddle.to_tensor(np_data, dtype=self.dtype) + v = paddle.to_tensor(value_data, dtype=self.dtype) + + np_data[:, [0, 2]] = np_data[:, [0, 2]] + np.expand_dims(value_data, -1) + x[:, [0, 2]] = x[:, [0, 2]] + v.unsqueeze(-1) + + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') + + np.testing.assert_allclose(x.numpy(), np_data) + + def test_combined_indexing_and_value_is_tensor_2(self): + # value is tensor needed to broadcast and index will be adjusted + np_data = np.ones((3, 4, 5, 6)).astype(self.ndtype) + value_data = np.arange(3 * 4 * 2 * 1).reshape((3, 4, 2, 1)) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + value_data = convert_uint16_to_float( + convert_float_to_uint16(value_data) + ) + if self.dtype == 'complex64' or self.dtype == 'complex128': + np_data = np_data + 1j * np_data + value_data = value_data + 1j * value_data + + x = paddle.to_tensor(np_data, dtype=self.dtype) + v = paddle.to_tensor(value_data, dtype=self.dtype) + x[..., [1, 4], ::2] = v + + np_data[..., [1, 4], ::2] = value_data + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') + np.testing.assert_allclose(x.numpy(), np_data) + + def test_combined_indexing_and_value_is_tensor_3(self): + # value is tensor and index will be adjusted + # and the value rank is less than original tensor + np_data = np.ones((3, 4, 5, 6)).astype(self.ndtype) + value_data = np.arange(2 * 3 * 5).reshape((2, 3, 5)) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + value_data = convert_uint16_to_float( + convert_float_to_uint16(value_data) + ) + if self.dtype == 'complex64' or self.dtype == 'complex128': + np_data = np_data + 1j * np_data + value_data = value_data + 1j * value_data + + x = paddle.to_tensor(np_data, dtype=self.dtype) + v = paddle.to_tensor(value_data, dtype=self.dtype) + x[:, [1, 3], :, [3, 4]] = v + + np_data[:, [1, 3], :, [3, 4]] = value_data + + if self.dtype == 'bfloat16': + x = paddle.cast(x, dtype='float32') + np.testing.assert_allclose(x.numpy(), np_data) + def test_inplace_with_stride(self): np_v = np.random.randn(3, 1).astype(self.ndtype) if self.dtype == 'bfloat16': @@ -574,6 +647,72 @@ def test_indexing_is_boolean_false(self): np.testing.assert_allclose(res[0], np_data) + @test_with_pir_api + def test_combined_indexing_and_value_is_tensor_1(self): + # value is tensor with same shape to getitem and index will be adjusted + np_data = np.ones((3, 3), dtype='int32') + value_data = np.array([-1, -1, -1]) + np_data[:, [0, 2]] = np_data[:, [0, 2]] * np.expand_dims(value_data, -1) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 3), dtype='int32') + v = paddle.to_tensor([-1, -1, -1]) + y = _setitem_static( + x, + (slice(None), [0, 2]), + x[:, [0, 2]] * v.unsqueeze(-1), + ) + res = self.exe.run(fetch_list=[y]) + + np.testing.assert_allclose(res[0], np_data) + + @test_with_pir_api + def test_combined_indexing_and_value_is_tensor_2(self): + # value is tensor needed to broadcast and index will be adjusted + np_data = np.ones((3, 4, 5, 6), dtype='int32') + value_data = np.arange(3 * 4 * 2 * 1).reshape((3, 4, 2, 1)) + np_data[..., [1, 4], ::2] = value_data + + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + v = paddle.arange(3 * 4 * 2 * 1).reshape((3, 4, 2, 1)) + + y = _setitem_static( + x, + (..., [1, 4], slice(None, None, 2)), + v, + ) + + res = self.exe.run(fetch_list=[y]) + + np.testing.assert_allclose(res[0], np_data) + + @test_with_pir_api + def test_combined_indexing_and_value_is_tensor_3(self): + # value is tensor and index will be adjusted + # and the value rank is less than original tensor + np_data = np.ones((3, 4, 5, 6), dtype='int32') + value_data = np.arange(2 * 3 * 5).reshape((2, 3, 5)) + np_data[:, [1, 3], :, [3, 4]] = value_data + + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.ones((3, 4, 5, 6), dtype='int32') + v = paddle.arange(2 * 3 * 5).reshape((2, 3, 5)) + y = _setitem_static( + x, + (slice(None), [1, 3], slice(None), [3, 4]), + v, + ) + + res = self.exe.run(fetch_list=[y]) + + np.testing.assert_allclose(res[0], np_data) + if __name__ == '__main__': unittest.main() From 5601f81b835dc59ebfcf208a0ca18020e67cd9d3 Mon Sep 17 00:00:00 2001 From: JYChen Date: Tue, 2 Jan 2024 15:28:16 +0800 Subject: [PATCH 3/5] Set value with scalar (#60452) * set_value with scalar * fix ut --- paddle/fluid/pybind/eager_method.cc | 92 +++++++++------- paddle/fluid/pybind/slice_utils.h | 101 ++++++++++++++++++ .../base/dygraph/tensor_patch_methods.py | 11 +- 3 files changed, 157 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 37e1e80774d22..1bfc34a2bed10 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1608,12 +1608,9 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, &use_strided_slice); // step2: Parse values - PADDLE_ENFORCE( - PyCheckTensor(value_obj), - platform::errors::InvalidArgument("The value must be a Tensor")); - + std::vector values; paddle::Tensor value_tensor = - reinterpret_cast(value_obj)->tensor; + dealWithValues(tensor, value_obj, &values, has_advanced_index); if (!has_advanced_index) { // use set_value OP if there is no advanced index @@ -1621,45 +1618,60 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, // Release gil and do tracing py::gil_scoped_release release; // use inplace set_value_ operator - if (value_tensor.initialized() && - (self->tensor.dtype() != value_tensor.dtype())) { - if (egr::Controller::Instance().GetAMPLevel() != - paddle::imperative::AmpLevel::O0) { - paddle::small_vector, - egr::kSlotSmallVectorSize> - tmps = {{self->tensor}, {value_tensor}}; - auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps); - self->tensor = egr::EagerAmpAutoCast( - self->tensor.name(), self->tensor, amp_dtype, "set_value"); - value_tensor = egr::EagerAmpAutoCast( - value_tensor.name(), value_tensor, amp_dtype, "set_value"); - } + if (value_tensor.initialized()) { if (self->tensor.dtype() != value_tensor.dtype()) { - value_tensor = cast_ad_func(value_tensor, self->tensor.dtype()); + if (egr::Controller::Instance().GetAMPLevel() != + paddle::imperative::AmpLevel::O0) { + paddle::small_vector, + egr::kSlotSmallVectorSize> + tmps = {{self->tensor}, {value_tensor}}; + auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps); + self->tensor = egr::EagerAmpAutoCast( + self->tensor.name(), self->tensor, amp_dtype, "set_value"); + value_tensor = egr::EagerAmpAutoCast( + value_tensor.name(), value_tensor, amp_dtype, "set_value"); + } + if (self->tensor.dtype() != value_tensor.dtype()) { + value_tensor = cast_ad_func(value_tensor, self->tensor.dtype()); + } } - } - // step3.1: Only basic indexing, use OP set_value. - const phi::distributed::ProcessMesh* mesh = nullptr; - if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) { - ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor); - } - self->tensor = set_value_with_tensor__ad_func(self->tensor, - value_tensor, - slice_starts, - slice_ends, - slice_strides, - slice_axes, - decrease_axis, - none_axes); - if (PyCheckTensor(value_obj)) { - // pass the stop_gradient from value to tensor. - // pass stop gradient should be done after CheckInplace in - // set_value__dygraph_function. - if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() && - egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) { - egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false); + // step3.1: Only basic indexing, use OP set_value. + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) { + ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor); + } + self->tensor = set_value_with_tensor__ad_func(self->tensor, + value_tensor, + slice_starts, + slice_ends, + slice_strides, + slice_axes, + decrease_axis, + none_axes); + if (PyCheckTensor(value_obj)) { + // pass the stop_gradient from value to tensor. + // pass stop gradient should be done after CheckInplace in + // set_value__dygraph_function. + if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() && + egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) { + egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false); + } + } + } else { + const phi::distributed::ProcessMesh* mesh = nullptr; + if (InputsContainDistTensor(&mesh, self->tensor)) { + ConvertAllInputsToDistTensor(mesh, self->tensor); } + self->tensor = set_value__ad_func(self->tensor, + slice_starts, + slice_ends, + slice_strides, + slice_axes, + decrease_axis, + none_axes, + {1}, + values); } } else { // step3.2: Case for there are advanced indexing. diff --git a/paddle/fluid/pybind/slice_utils.h b/paddle/fluid/pybind/slice_utils.h index 73ad179d6782a..bc3ac16cfe66d 100644 --- a/paddle/fluid/pybind/slice_utils.h +++ b/paddle/fluid/pybind/slice_utils.h @@ -26,9 +26,11 @@ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -508,5 +510,104 @@ static void ParseBoolAndBroadcastIndices( } } +static paddle::Tensor dealWithValues(const paddle::Tensor& tensor, + PyObject* value_obj, + std::vector* values, + const bool trans_to_tensor) { + paddle::Tensor value_tensor; + if (PyCheckTensor(value_obj)) { + value_tensor = reinterpret_cast(value_obj)->tensor; + } else if (py::isinstance(value_obj)) { + paddle::Tensor value_tensor_tmp( + std::make_shared(), + egr::Controller::Instance().GenerateUniqueName()); + py::object value_obj_tmp(py::handle(value_obj), true); + py::object value = value_obj_tmp; + if (tensor.dtype() == phi::DataType::FLOAT32) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::FLOAT64) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::INT32) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::INT64) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::BOOL) { + if (!py::isinstance>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray(value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::COMPLEX64) { + if (!py::isinstance>>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray>( + value_obj_tmp); + } + } else if (tensor.dtype() == phi::DataType::COMPLEX128) { + if (!py::isinstance>>(value_obj_tmp)) { + value = pybind11::detail::CastNumpyArray>( + value_obj_tmp); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "When assign a numpy.np value to a paddle.Tensor, " + "the data type of the paddle.Tensor must be bool, " + "float32, float64, complex64, complex128, int32 or int64, " + "please check the type of tensor.")); + } + SetTensorFromPyArray( + static_cast(value_tensor_tmp.impl().get()), + value, + tensor.place(), + false); + value_tensor = value_tensor_tmp; + } else { + py::object value_obj_tmp(py::handle(value_obj), true); + // convert the value to self data type + if (py::isinstance(value_obj_tmp) || + py::isinstance(value_obj_tmp) || + py::isinstance(value_obj_tmp) || + PyComplex_Check(value_obj)) { + if (tensor.dtype() == phi::DataType::FLOAT32 || + tensor.dtype() == phi::DataType::FLOAT16 || + tensor.dtype() == phi::DataType::BFLOAT16) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::FLOAT64) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::INT32 || + tensor.dtype() == phi::DataType::INT16 || + tensor.dtype() == phi::DataType::INT8 || + tensor.dtype() == phi::DataType::UINT8) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::INT64) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::BOOL) { + values->push_back(value_obj_tmp.cast()); + } else if (tensor.dtype() == phi::DataType::COMPLEX64) { + values->push_back(value_obj_tmp.cast>()); + } else if (tensor.dtype() == phi::DataType::COMPLEX128) { + values->push_back(value_obj_tmp.cast>()); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Value type error. The assign value allows " + "Tensor, numpy.ndarray, integer, float, complex or bool, " + "but received %s.", + Py_TYPE(value_obj))); + } + + if (trans_to_tensor) { + value_tensor = + full_ad_func({1}, (*values)[0], tensor.dtype(), tensor.place()); + } + } + return value_tensor; +} + } // namespace pybind } // namespace paddle diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py index cc12d50a6069f..7b4c81cfa323d 100644 --- a/python/paddle/base/dygraph/tensor_patch_methods.py +++ b/python/paddle/base/dygraph/tensor_patch_methods.py @@ -876,7 +876,7 @@ def __array__(self, dtype=None): array = array.astype(dtype) return array - def pre_deal_index_and_value(self, item, value=None): + def pre_deal_index(self, item): # since in pybind there is no effiency way to transfer Py_Tuple/Py_List/Py_Range to Tensor # we call this function in python level. item = list(item) if isinstance(item, tuple) else [item] @@ -886,17 +886,14 @@ def pre_deal_index_and_value(self, item, value=None): elif isinstance(slice_item, range): item[i] = paddle.to_tensor(list(slice_item)) - if value is not None and not isinstance(value, Variable): - value = paddle.to_tensor(value, dtype=self.dtype) - - return tuple(item), value + return tuple(item) def __getitem__(self, item): - item, _ = pre_deal_index_and_value(self, item) + item = pre_deal_index(self, item) return self._getitem_dygraph(item) def __setitem__(self, item, value): - item, value = pre_deal_index_and_value(self, item, value) + item = pre_deal_index(self, item) return self._setitem_dygraph(item, value) @framework.dygraph_only From 5eaba8ede5e70dca7c0e5109c74b3531ca1b3716 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 18 Jan 2024 04:11:51 +0000 Subject: [PATCH 4/5] remove test_pir --- test/indexing/test_setitem.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index 8e1b0bbe72a04..781b32cb9183b 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -647,7 +647,6 @@ def test_indexing_is_boolean_false(self): np.testing.assert_allclose(res[0], np_data) - @test_with_pir_api def test_combined_indexing_and_value_is_tensor_1(self): # value is tensor with same shape to getitem and index will be adjusted np_data = np.ones((3, 3), dtype='int32') @@ -667,7 +666,6 @@ def test_combined_indexing_and_value_is_tensor_1(self): np.testing.assert_allclose(res[0], np_data) - @test_with_pir_api def test_combined_indexing_and_value_is_tensor_2(self): # value is tensor needed to broadcast and index will be adjusted np_data = np.ones((3, 4, 5, 6), dtype='int32') @@ -690,7 +688,6 @@ def test_combined_indexing_and_value_is_tensor_2(self): np.testing.assert_allclose(res[0], np_data) - @test_with_pir_api def test_combined_indexing_and_value_is_tensor_3(self): # value is tensor and index will be adjusted # and the value rank is less than original tensor From 6fc8f98ed8a204a931a73539f775117563c33266 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 18 Jan 2024 06:20:21 +0000 Subject: [PATCH 5/5] remove one test since 2.6 not support uint8-add --- test/indexing/test_setitem.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/test/indexing/test_setitem.py b/test/indexing/test_setitem.py index 781b32cb9183b..33433c428c030 100644 --- a/test/indexing/test_setitem.py +++ b/test/indexing/test_setitem.py @@ -228,31 +228,6 @@ def test_indexing_is_boolean_false(self): np.testing.assert_allclose(x.numpy(), np_data) - def test_combined_indexing_and_value_is_tensor_1(self): - # value is tensor with same shape to getitem and index will be adjusted - np_data = np.ones((3, 3)).astype(self.ndtype) - value_data = np.array([-1, -1, -1]).astype(self.ndtype) - - if self.dtype == 'bfloat16': - np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) - value_data = convert_uint16_to_float( - convert_float_to_uint16(value_data) - ) - if self.dtype == 'complex64' or self.dtype == 'complex128': - np_data = np_data + 1j * np_data - value_data = value_data + 1j * value_data - - x = paddle.to_tensor(np_data, dtype=self.dtype) - v = paddle.to_tensor(value_data, dtype=self.dtype) - - np_data[:, [0, 2]] = np_data[:, [0, 2]] + np.expand_dims(value_data, -1) - x[:, [0, 2]] = x[:, [0, 2]] + v.unsqueeze(-1) - - if self.dtype == 'bfloat16': - x = paddle.cast(x, dtype='float32') - - np.testing.assert_allclose(x.numpy(), np_data) - def test_combined_indexing_and_value_is_tensor_2(self): # value is tensor needed to broadcast and index will be adjusted np_data = np.ones((3, 4, 5, 6)).astype(self.ndtype)