
use elementwise to optimize gelu backward implementation on GPU #38263

Merged: 6 commits into PaddlePaddle:develop on Dec 22, 2021

Conversation

@Zjq9409 (Contributor) commented Dec 19, 2021

PR types

Performance optimization

PR changes

OPs

Describe

Use elementwise to optimize the GPU backward computation of the gelu operator. With both the forward and backward computations optimized, the performance data are as follows:
[image: performance data table]
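For reference, the tanh approximation of gelu that this operator implements, and its derivative (which the backward kernel evaluates per element), are:

gelu(x)  = 0.5 * x * (1 + tanh(kAlpha * (x + 0.044715 * x^3))),   kAlpha = sqrt(2 / pi) ≈ 0.79788456
gelu'(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * (kAlpha + 3 * 0.044715 * kAlpha * x^2),   u = kAlpha * (x + 0.044715 * x^3)
dx       = dout * gelu'(x)

Because dx depends only pointwise on x and dout, the backward pass maps directly onto Paddle's elementwise CUDA kernel machinery, which is what this PR switches to.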

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@ZzSean (Contributor) commented Dec 20, 2021

Please make the PR title a bit more descriptive, so that it matches the previous PR.

@Zjq9409 changed the title from "optimize gelu backward" to "use elementwise to optimize gelu backward implementation on GPU" on Dec 20, 2021
@Zjq9409 (Contributor, Author) commented Dec 21, 2021

> Please make the PR title a bit more descriptive, so that it matches the previous PR.

Done.

auto tanh_out =
    tanh(kAlpha * x * (one + static_cast<MPType>(0.044715) * x * x));
auto ans =
    half * x * ((one - tanh_out * tanh_out) *
                (kAlpha + static_cast<MPType>(0.1070322243) * x * x)) +
Reviewer (Contributor):

Do not use magic numbers that are not part of the formula; replace them all with expressions.

Author: Done.
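For context, 0.1070322243 is not arbitrary: it equals kAlpha * 3 * 0.044715, i.e. sqrt(2/pi) times the factor produced by differentiating x + 0.044715 * x^3. A sketch of the de-magicked expression, in the kBeta form the final code adopts:

MPType kBeta =
    kAlpha * static_cast<MPType>(0.044715) * static_cast<MPType>(3);
auto ans =
    half * x * ((one - tanh_out * tanh_out) * (kAlpha + kBeta * x * x)) +
    half * (one + tanh_out);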

auto tanh_out =
    tanh(kAlpha * x * (one + static_cast<MPType>(0.044715) * x * x));
auto ans =
    half * x * ((one - tanh_out * tanh_out) *
Reviewer (Contributor):

Please simplify the formula further.

Author: Done.
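Concretely, the simplification groups the derivative so that each factor is computed once:

gelu'(x) = half * x * temp + half * (1 + tanh_out),   where temp = (1 - tanh_out^2) * (kAlpha + kBeta * x^2) and kBeta = kAlpha * 0.044715 * 3

This is the grouping the committed code below uses.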

@@ -12,9 +12,76 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"
Reviewer (Contributor):

This header is already included by paddle/fluid/operators/amp/fp16_type_traits.h, so this include can be removed.

Author: Done.
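To show how these includes fit together: below is a sketch of the elementwise backward functor the new code defines, reconstructed from the fragments quoted in this review. The functor name is assumed for illustration; fp16_type_traits.h supplies MPTypeTrait, which promotes float16 to float so the math runs at full precision, and HOSTDEVICE is Paddle's __host__ __device__ marker.

template <typename T>
struct GeluWithApproximateGradFunctor {
  using MPType = typename details::MPTypeTrait<T>::Type;
  inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
    MPType x = static_cast<MPType>(arg_x);
    MPType dout = static_cast<MPType>(arg_dout);
    MPType one = static_cast<MPType>(1);
    MPType half = static_cast<MPType>(0.5);
    // kAlpha = sqrt(2 / pi); M_2_SQRTPI * M_SQRT1_2 = (2 / sqrt(pi)) * (1 / sqrt(2))
    MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
    MPType decimal = static_cast<MPType>(0.044715);
    MPType kBeta = kAlpha * decimal * static_cast<MPType>(3);
    auto tanh_out = tanh(kAlpha * x * (one + decimal * x * x));
    auto temp = (one - tanh_out * tanh_out) * (kAlpha + kBeta * x * x);
    auto ans = half * x * temp + half * (one + tanh_out);
    return static_cast<T>(dout * ans);
  }
};

The functor is then handed to the elementwise launch helper from elementwise_op_broadcast.cu.h with {x, dout} as inputs and {dx} as the output; the exact launch-helper signature in this revision is not shown in the excerpts here.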

(one + tanh(static_cast<MPType>(0.79788456) * x *
(one + static_cast<MPType>(0.044715) * x * x)));
MPType half = static_cast<MPType>(0.5);
MPType decimal = static_cast<MPType>(0.044715);
Author:

// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
This is the gelu formula; the "0.044715" in it is a fixed constant.

Reviewer (Contributor):

The constant 0.044715 can be expressed with a macro, for example:

#define GELU_CONSTANT 0.044715

Put the macro in a shared header such as gelu_op.h, and update all code that uses 0.044715 accordingly.

Author: Done.
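A minimal sketch of the suggested change, assuming the macro lands in gelu_op.h:

// in paddle/fluid/operators/gelu_op.h (shared by CPU and GPU code)
#define GELU_CONSTANT 0.044715

// at each call site, instead of a bare 0.044715 literal:
MPType decimal = static_cast<MPType>(GELU_CONSTANT);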

MPType dout = static_cast<MPType>(arg_dout);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType decimal = static_cast<MPType>(0.044715);
Reviewer (Contributor):

The magic numbers are still present in this part.

Author: Done.


MPType kBeta = kAlpha * decimal * static_cast<MPType>(3);
auto tanh_out = tanh(kAlpha * x * (one + decimal * x * x));
auto temp = (one - tanh_out * tanh_out) * (kAlpha + kBeta * x * x);
auto ans = half * x * temp + half * (one + tanh_out);
Reviewer (Contributor):

This computation contains sub-expressions that are evaluated multiple times, such as x^3; pull these out into temporaries to reduce the amount of computation.

Author: Done.
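A sketch of the hoisting the reviewer asks for: the repeated x * x is computed once and reused (the variable name x_sq is illustrative):

auto x_sq = x * x;
auto tanh_out = tanh(kAlpha * x * (one + decimal * x_sq));
auto temp = (one - tanh_out * tanh_out) * (kAlpha + kBeta * x_sq);
auto ans = half * x * temp + half * (one + tanh_out);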

@JamesLim-sy (Contributor) left a comment:
LGTM

@JamesLim-sy JamesLim-sy merged commit 858e435 into PaddlePaddle:develop Dec 22, 2021
zmxdream pushed a commit to zmxdream/Paddle that referenced this pull request on Dec 25, 2021: use elementwise to optimize gelu backward implementation on GPU (PaddlePaddle#38263)

* optimize gelu backward

* optimize gelu backward

* optimize code

* Number to expression

* Replacement number