Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add index_put api #52886

Merged
merged 28 commits into from
May 10, 2023
Merged

Conversation

Courtesy-Xs
Copy link
Contributor

@Courtesy-Xs Courtesy-Xs commented Apr 13, 2023

PR types

New features

PR changes

APIs

Description

This PR add index_put and index_put_ API for Paddle, please refer to PaddlePaddle API doc for details.

(Supplementary Note: Due to some indexing mechanism problems of the Paddle framework, the performance of Paddle's ways to index is much slower than Torch, but the overall reconstruction is a process that takes time, so some advanced indexing with poor performance is firstly extracted for optimization and will expose them in the type of paddle API which are index_put and index_put_ API for users.
Advanced Indexing means using tensor as subscript to index a tensor. Under the functions supported by index_put API, its performance is far better than directly c-order indexing in paddle)

@paddle-bot
Copy link

paddle-bot bot commented Apr 13, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Courtesy-Xs Courtesy-Xs marked this pull request as draft April 17, 2023 12:46
@Courtesy-Xs Courtesy-Xs marked this pull request as ready for review April 17, 2023 12:47
@Courtesy-Xs Courtesy-Xs marked this pull request as draft April 17, 2023 12:50
@Courtesy-Xs Courtesy-Xs marked this pull request as ready for review April 17, 2023 12:50
#include "paddle/phi/kernels/index_put_grad_kernel.h"
#include <numeric>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cpu kernel里面不需要加这些gpu相关的头文件

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

int64_t offset = 0;

for (size_t i = 0; i < Rank; ++i) {
cur_ix = (int64_t(*(indices[i] + idx)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

数据类型转换用static_cast

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

value_grad->dtype(),
false,
&value_grad_dims_without1);
phi::ReshapeInferKernel<Context>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的目的是对value_gradResize吧?value_grad是输出,直接用value_grad->Resize(...)就行?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value_grad的size不能调用resize变化的,value_grad的dims会影响到反向梯度的shape,需保持与前向的value的shape一致

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

但你ReshapeInferKernel的调用,不也会修改value_grad的shape吗?我的意思是,在L190再调用一次value_grad->Resize,直接再次设置value_grad的shape,也可避免ReshapeInferKernel中的一次memcpy。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的ReshapeInferKernel本身并没有修改value_grad的shape

template <typename T, size_t Rank>
void set_zero_kernel(const int64_t N,
const int64_t** indices,
phi::Array<int64_t, Rank> stride,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CPU Kernel没必要用phi::Array,直接用const std::vector<int64_t>&const phi::DDim&类型就行,还能避免拷贝。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}

template <typename T, typename Context, size_t Rank>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CPU Kernel就不要将Rank作为模板了,你单测覆盖率没过正式因为Rank

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/nonzero_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要include这么多头文件

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


#include <vector>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要include fluid下面的头文件,使用phi目录下的代替

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -916,6 +916,7 @@ set_tests_properties(test_imperative_selected_rows_to_lod_tensor
PROPERTIES TIMEOUT 200)
set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_add_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_put_op PROPERTIES TIMEOUT 120)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个TIMEOUT一定要设置吗,默认是多少?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不设置的话,CI会超时,和CI的同学确认过了,默认的值的话,很小,貌似不过15s,当时CI报错超过15s直接timeout了,具体是多少不确定

@@ -0,0 +1,826 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2022 -> 2023

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done



"""
assert len(indices) != 0, "indices can't be empty"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

动态图可不加assert,算子内部负责检查吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@paddle-ci-bot
Copy link

paddle-ci-bot bot commented May 6, 2023

Sorry to inform you that 7b71a3a's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@Xreki
Copy link
Contributor

Xreki commented May 8, 2023

image

这个改一下吧。

@Courtesy-Xs
Copy link
Contributor Author

image

这个改一下吧。

这个看起来是新增API的问题,确认了一下,改了yaml都会这样,API和Op的参数是对齐的

zyfncg
zyfncg previously approved these changes May 8, 2023
@@ -3249,6 +3249,21 @@ void MoeInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

void IndexPutInferMeta(const MetaTensor& x,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InferMeta按照字母序放置

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

#pragma once

#include <vector>
#include "paddle/phi/common/place.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

place.h看上去不需要include

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

#pragma once

#include <vector>
#include "paddle/phi/common/place.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

infer_meta :
func : IndexPutInferMeta
kernel :
func : index_put
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

输入x和indices的数据类型不同,需要指定按照谁的数据类型来选择kernel,关键字为data_type,写法如后面紧跟的index_sample

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

const int64_t* pd_indices[7];
for (size_t i = 0; i < indices_v.size(); ++i) {
pd_indices[i] = indices_v[i]->data<int64_t>();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L108 - L111既然后续还会用到,就挪到L98吧,删除L121 - L124的重复代码。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

value_grad->dtype(),
false,
&value_grad_dims_without1);
phi::ReshapeInferKernel<Context>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

但你ReshapeInferKernel的调用,不也会修改value_grad的shape吗?我的意思是,在L190再调用一次value_grad->Resize,直接再次设置value_grad的shape,也可避免ReshapeInferKernel中的一次memcpy。

T* out = dev_ctx.template Alloc<T>(p_res);
range_kernel<T>(N, out);
return res;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

并不是说CPU、GPU Kernel里面重复,而是前向、反向中也有重复。通过模板+宏、或者设置不同的函数名来解决。

const int64_t** indices,
const phi::DDim& stride,
const phi::DDim& shape,
int64_t isSingleValTensor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isSingleValTensor -> is_single_val_tensor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

DenseTensor* out) {
auto* x_data = x.data<T>();
auto* val_data = value.data<T>();
bool isInitialized = out->initialized();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isInitialized -> is_initialized

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"

namespace phi {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加一层namespace funcs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

phi::DenseTensor res_tensor(tensor.dtype());
res_tensor.Resize(res_dim);
ExpandKernel<T, Context>(
dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里调Reshape和Expand都会产生memcpy,实际上只需要获得相应的dims

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我开始也想过是不是可以直接resize就行,之前在
tmp_indices_v.emplace_back(DenseTensor(phi::DataType::INT64).Resize(phi::make_ddim({nonzero_indices.dims()[0],1})));替换为
tmp_indices_v.emplace_back(DenseTensor(phi::DataType::INT64).Resize(phi::make_ddim({nonzero_indices.dims()[0]})));的时候我尝试过是否通过resize可以减少一些reshape操作

但是在这里我受限的点在于我需要一个能够满足expand关系的src tensor和一个des tensor来操作,但是我并不能修改tensor这个对象,因为它是一个const reference,所以这里两次的拷贝,可能是一个必要的

int64_t** indices,
phi::Array<int64_t, Rank> stride,
phi::Array<int64_t, Rank> shape,
int64_t isSingleValTensor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isSingleValTensor -> is_single_val_tensor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

T* out = dev_ctx.template Alloc<T>(p_res); \
range_kernel<T>(N, out); \
return res; \
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

真是没想到你会把整个函数全部写成一个宏。一个建议的最简单的修改方式如下:

template <typename T>
void range_kernel(int64_t N, T* out) {
  ... 
}

template <typename T, typename Context>
phi::DenseTensor GetRangeTensor(const Context& dev_ctx, int64_t N, phi::DataType dtype) {
  ...
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
__global__ void range_cuda_kernel(int64_t N, T* out) {
  ...
}

template <typename T, typename Context>
phi::DenseTensor GetRangeCudaTensor(
    const Context& dev_ctx, int64_t N, phi::DataType dtype) {
  ...
}
#endif

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Xreki
Xreki previously approved these changes May 8, 2023
Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Xreki
Xreki previously approved these changes May 9, 2023
Copy link
Contributor

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM


Args:
x (Tensor) : The Source Tensor. Supported data types are int32, int64, float16, float32, float64, bool.
indices (Tensor): The tuple of Tensor containing the indices to index.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该是List / tuple of Tensor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是tuple of tensor,对齐的torch对应的API的用法

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

indices (Tuple of Tensor): The tuple of Tensor containing the indices to index.
The data type of ``tensor in indices`` must be int32, int64 or bool
value (Tensor): The tensor used to be assigned to x.
accummulate (Bool): Whether the elements in values are added to x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有默认值的参数需要注明optional,以及default是什么

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def index_put(x, indices, value, accumulate=False, name=None):
"""
Outplace version of ``index_put_`` API, the output Tensor will be inplaced with input ``x``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一般是会说 index_put_ 是 index_put 的 inplace 版本,能否反过来说辛苦文档pm确认下 @sunzhongkai588


Returns:
Tensor, same dimention and dtype with x.
Examples:
Copy link
Contributor

@Ligoml Ligoml May 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要确认一下官网预览效果,等 PR-CI-Paddle-Doc-Preview 跑完

lanxianghit
lanxianghit previously approved these changes May 9, 2023
zyfncg
zyfncg previously approved these changes May 9, 2023
@Courtesy-Xs Courtesy-Xs dismissed stale reviews from zyfncg and lanxianghit via b09221f May 9, 2023 12:37
Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

Copy link
Contributor

@XieYunshen XieYunshen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for set_tests_properties(test_index_put_op PROPERTIES TIMEOUT 120)

Copy link
Contributor

@XieYunshen XieYunshen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for set_tests_properties(test_index_put_op PROPERTIES TIMEOUT 120)

@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators May 9, 2023
@PaddlePaddle PaddlePaddle unlocked this conversation May 9, 2023
@Xreki Xreki merged commit f3393f4 into PaddlePaddle:develop May 10, 2023
@Courtesy-Xs Courtesy-Xs deleted the clear_add_index_put_api branch July 7, 2023 03:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

8 participants