Add a rotate_half implementation to the fused_rope operator #56401

Merged: 8 commits merged into PaddlePaddle:develop on Sep 4, 2023

Conversation

tianhaodongbd (Contributor) commented Aug 17, 2023

PR types

Others

PR changes

OPs

Description

Pcard-70459

Add a rotate_half implementation to the fused_rope operator, controlled by the use_neox_rotary_style flag: True selects the rotate_every_two implementation and False selects the rotate_half implementation. The default is True.
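For reference, a minimal NumPy sketch (illustrative only, not the Paddle kernel) of how the two styles rotate a single head vector x of even length, given sin/cos tables of the same length:

import numpy as np

def rotate_every_two(x, sin, cos):
    # Pairs adjacent elements (x0, x1), (x2, x3), ... and rotates each pair.
    rotated = np.empty_like(x)
    rotated[0::2] = -x[1::2]
    rotated[1::2] = x[0::2]
    return x * cos + rotated * sin

def rotate_half(x, sin, cos):
    # Pairs element i with element i + head_dim // 2 and rotates each pair.
    half = x.shape[-1] // 2
    rotated = np.concatenate([-x[half:], x[:half]])
    return x * cos + rotated * sin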

paddle-bot commented Aug 17, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -102,5 +103,91 @@ __global__ void VectorizedFusedRopeKernel(phi::Array<const T*, 3> ins_data,
}
}

template <typename T, typename MPType, int VecSize = 2>
__global__ void VectorizedFusedRopeWithRotateHalfKernel(
Contributor:

  • On CUDA you can define __device__ functions, and a __device__ function can be called from a __global__ function.
  • Splitting into two __global__ functions is also acceptable, but large blocks of copied code should still be avoided.

Contributor Author:

Done.

@@ -27,6 +29,7 @@ def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
use_neox_rotary_style(optional|bool): Use "rotate_every_two" when use_neox_rotary_style is True, use "ratate_half" when use_neox_rotary_style is False. Default True.
Contributor:

Explaining it this way is not intuitive; rotate_every_two and rotate_half are not commonly understood, self-explanatory terms.
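For reference, both styles apply the same two-dimensional rotation and differ only in which elements are paired:

x'_i = x_i * cos(theta) - x_j * sin(theta)
x'_j = x_i * sin(theta) + x_j * cos(theta)

rotate_every_two pairs adjacent elements, (i, j) = (2t, 2t+1), while rotate_half pairs each element with its counterpart half the head dimension away, (i, j) = (t, t + head_dim/2).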

Contributor Author:

Done.

# to [1, 1, seq_len, head_dim]
perm = [0, 2, 1, 3]
sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)
Contributor:

Whether use_neox_rotary_style is True or False, only the update logic for q/k/v differs; the computation of sin/cos is identical, so there is no need to repeat the sin/cos computation in both branches of the if-else.
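A hypothetical sketch of this suggestion (toy shapes, not the PR's code), hoisting the shared transpose out of the branch:

import paddle

use_neox_rotary_style = True
sin_tensor = paddle.randn([1, 8, 1, 10])
cos_tensor = paddle.randn([1, 8, 1, 10])
# The reshape to [1, 1, seq_len, head_dim] is identical for both styles,
# so it is done once before branching.
perm = [0, 2, 1, 3]
sin_tensor = paddle.transpose(x=sin_tensor, perm=perm)
cos_tensor = paddle.transpose(x=cos_tensor, perm=perm)
if use_neox_rotary_style:
    pass  # rotate_every_two update of q/k/v goes here
else:
    pass  # rotate_half update of q/k/v goes here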

Contributor Author:

Done.

if (position_ids.get_ptr()) {
position_ids_data = position_ids->data<int64_t>();

flag_position_ids = true;
Contributor:

There is no need to add such a flag. L63 initializes position_ids_data to null, so the kernel can simply check whether the pointer is null.

Contributor Author:

Done.

bool flag_position_ids = false;
if (position_ids.get_ptr()) {
position_ids_data = position_ids->data<int64_t>();

Contributor:

The dimensions of position_ids need to be checked, position_ids is only needed when sin and cos are passed in, and the shape-check logic for sin and cos needs to be updated accordingly.

[image: sin/cos indexed by position_ids]

In other words, sin and cos only have shape [1, seq_len, 1, head_dim] after being sliced by the indices in position_ids; the tensors actually passed in may have a larger shape.
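A hypothetical Paddle sketch of that slicing (names and shapes illustrative): sin holds rows for up to max_seq_len positions, and position_ids selects one row per token:

import paddle

max_seq_len, batch_size, seq_len, head_dim = 32, 2, 8, 10
sin_table = paddle.randn([max_seq_len, head_dim])
position_ids = paddle.randint(high=max_seq_len, shape=[batch_size, seq_len], dtype='int64')
# Gather one sin row per (batch, position) token, then restore the batch layout.
rows = paddle.gather(sin_table, position_ids.flatten())
sin_selected = rows.reshape([batch_size, seq_len, 1, head_dim])
print(sin_selected.shape)  # [2, 8, 1, 10]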

Contributor Author:

Done.

paddle/phi/kernels/fusion/gpu/fused_rope_utils.h (outdated comment, resolved)
@@ -27,6 +35,8 @@ def fused_rotary_position_embedding(q, k=None, v=None, sin=None, cos=None):
v (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if v must be [batch_size, seq_len, num_heads, head_dim] and head_dim must be a multiple of 2.
sin (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if sin must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
cos (optional|Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64. The shape if cos must be [seq_len, head_dim] or [1, 1, seq_len, head_dim] and head_dim must be a multiple of 2.
position_ids (optional|Tensor): The input tensor. The data type is int64. The shape if position_ids must be [batch_size, seq_len].
Contributor:

"The shape if" -> "The shape of". Also, the parameter format in the docs should be: position_ids (Tensor, optional).

Contributor Author:

Done.

Xreki previously approved these changes Sep 1, 2023

Xreki (Contributor) left a comment:

LGTM. It is suggested to change the PR title to English.

Comment on lines 50 to 75
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

q = paddle.randn([1, 1, 4, 10], dtype='float16')
k = paddle.randn([1, 1, 4, 10], dtype='float16')
v = paddle.randn([1, 1, 4, 10], dtype='float16')
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v)
# batch_size = 2
# seq_len = 8
# num_heads = 2
# head_dim = 10

x = paddle.randn([1, 1, 1, 10], dtype='float16')
y = paddle.randn([1, 1, 1, 10], dtype='float16')
# q, k, v: [batch_size, seq_len, num_heads, head_dim]
q = paddle.randn([2, 8, 2, 10], dtype='float16')
k = paddle.randn([2, 8, 2, 10], dtype='float16')
v = paddle.randn([2, 8, 2, 10], dtype='float16')

# sin, cos: [1, seq_len, 1, head_dim]
x = paddle.randn([1, 8, 1, 10], dtype='float16')
y = paddle.randn([1, 8, 1, 10], dtype='float16')
sin = paddle.sin(x)
cos = paddle.cos(y)
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos)

# position_ids: [batch_size, seq_len]
position_ids = paddle.randint(high=8, shape=[2, 8], dtype='int64')

# out_q, out_k, out_v: [batch_size, seq_len, num_heads, head_dim]
out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, sin=sin, cos=cos, position_ids=position_ids, use_neox_rotary_style=False)
print(out_q.shape)
# [2, 8, 2, 10]
Contributor:

Please strictly follow the Google style for code examples, i.e. prefix code with >>> and ..., and if there is output (e.g. print(out_q.shape)), add the exact output after it. See the API documentation writing guide and the sample-code style guide.

Note:

  • This code uses APIs with randomness such as randn; please add a seed in the code so the output is fixed and easy to check.
  • Regarding # required: gpu — does this example need a GPU environment? If so, add the doctest directive at the beginning of the code: >>> # doctest: +REQUIRES(env:GPU)
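A hedged sketch of the requested format (the seed value and GPU directive are illustrative; only the shape is printed so the output is stable):

>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle
>>> from paddle.incubate.nn.functional import fused_rotary_position_embedding
>>> _ = paddle.seed(2023)
>>> q = paddle.randn([2, 8, 2, 10], dtype='float16')
>>> k = paddle.randn([2, 8, 2, 10], dtype='float16')
>>> v = paddle.randn([2, 8, 2, 10], dtype='float16')
>>> out_q, out_k, out_v = fused_rotary_position_embedding(q, k, v, use_neox_rotary_style=False)
>>> print(out_q.shape)
[2, 8, 2, 10]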

Xreki previously approved these changes Sep 1, 2023

Xreki (Contributor) left a comment:

LGTM

XiaoguangHu01 (Contributor) left a comment:

LGTM

sunzhongkai588 (Contributor) left a comment:

LGTM. The API seems to be public, so remember to add the Chinese documentation as well.

MPType* sin_value = out_sin;
MPType* cos_value = out_cos;

if (flag_sin_cos) {
Contributor:

The naming of this parameter does not seem easy to understand; how about reuse_***?

Contributor Author:

OK, will fix in the next PR.

@@ -148,11 +148,11 @@
optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index

- op : fused_rotary_position_embedding
args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true)
Contributor:

LGTM for add inputs

MARD1NO (Contributor) left a comment:

LGTM

@@ -86,21 +89,42 @@ void FusedRopeGradKernel(const Context& dev_ctx,
sin_cos_data[1] = cos->data<T>();

flag_sin_cos = true;

if (position_ids.get_ptr()) {
Contributor:

This should be able to just use if (position_ids) directly here.

Contributor Author:

OK, will fix in the next PR.

num_inputs,
div_c);
if (use_neox_rotary_style) {
VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
Contributor:

I think it might be better to rename the kernel to VectorizedFusedNeoxRopeKernel.

Contributor Author:

OK, will fix in the next PR.

Xreki merged commit c089a2a into PaddlePaddle:develop Sep 4, 2023
25 of 26 checks passed
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023
* add rotate_half in fused_rope

* add position_ids in fused_rope

* modified examples about fused_rope

* add set_device in examples