Skip to content

Commit

Permalink
Optimize FastGelu with float2 and float4 vectorized kernels on ROCm (#…
Browse files Browse the repository at this point in the history
…11491)

* Using vectorized loads (float2) for fp16 to improve performance

* Fix a few warnings from cpplint

* Fix a few warnings from cpplint

* Use __float2half2_rn and fix some cpplint warnings

* Move some computaions to LaunchFastGeluKernel

* Fix some Lint C++ warning

* Using vectorized loads (float4) for fp16 to improve performance

* Switch   whether to optimize FastGelu with float4 vectorization

* Switch to float4 memory access based on input_length in FastGelu

* Comment how to set the threshold of float2 and float4 vectorized kernels

* Add FastGelu fp16 unit tests for bias_length = 2 and 8

* Make vectorized kernels generic with aligned_vector

* Unify the vectorized kernels with/without bias

* Refactor the code to suppress cpplint warnings

* Solve formatting issues

* Remove cudaDeviceProp from FastGeluKernel and LaunchFastGeluKernel

* Move fast_gelu_impl.h to rocm/bert

* Fix some Lint C++ warnings and code alignment
  • Loading branch information
hubertlu-tw authored Jun 24, 2022
1 parent 088bc74 commit f4ba199
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 144 deletions.
73 changes: 73 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Modifications: Remove GetDeviceProp in LaunchFastGeluKernel.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/miopen_common.h"
#include "contrib_ops/rocm/bert/fast_gelu.h"
#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
#include "contrib_ops/rocm/bert/transformer_common.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
FastGelu, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
FastGelu<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

using namespace ONNX_NAMESPACE;

template <typename T>
FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
const TransformerOptions* options = TransformerOptions::GetInstance();
use_half2_ = !options->DisableHalf2();
}

template <typename T>
Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));

const Tensor* input = context->Input<Tensor>(0);
const Tensor* bias = context->Input<Tensor>(1);
Tensor* output = context->Output(0, input->Shape());

int64_t input_length = input->Shape().Size();
if (input_length == 0) {
return Status::OK();
}
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToHipType<T>::MappedType HipT;

if (!LaunchFastGeluKernel<HipT>(Stream(),
static_cast<int>(input_length),
static_cast<int>(bias_length),
reinterpret_cast<const HipT*>(input->template Data<T>()),
(nullptr != bias) ? reinterpret_cast<const HipT*>(bias->template Data<T>()) : nullptr,
reinterpret_cast<HipT*>(output->template MutableData<T>()),
use_half2_)) {
HIP_CALL(hipGetLastError());
return Status(common::ONNXRUNTIME, common::FAIL);
}

return Status::OK();
}

} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
26 changes: 26 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

using namespace onnxruntime::rocm;

template <typename T>
class FastGelu final : public RocmKernel {
public:
FastGelu(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;

private:
bool use_half2_;
};

} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
246 changes: 105 additions & 141 deletions onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ limitations under the License.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Modifications: Add FastGeluKernelVec to leverage vectorized load/write
// and modify FastGeluKernel to get better performance.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_call.h"
Expand All @@ -32,21 +37,15 @@ namespace onnxruntime {
namespace contrib {
namespace rocm {

// constants for approximating the normal cdf
constexpr float A = 0.5;

constexpr float B = 0.7978845608028654; // sqrt(2.0/M_PI)

constexpr float C = 0.035677408136300125; // 0.044715 * sqrt(2.0/M_PI)

constexpr float one = 1.0;
constexpr float two = 2.0;

template <typename T, unsigned TPB>
__global__ void FastGeluKernel(const T a, const T b, const T c, const T oneT, const T twoT,
int input_length, int bias_length, const T* input, const T* bias, T* output) {
__global__ void FastGeluKernel(int input_length, int bias_length, const T* input, const T* bias, T* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;

// constants for approximating the normal cdf
const T a = T(0.5f);
const T b = T(0.7978845608028654f); // sqrt(2.0/M_PI)
const T c = T(0.035677408136300125f); // 0.044715 * sqrt(2.0/M_PI)
const T oneT = T(1.0f);
const T twoT = T(2.0f);
if (idx < input_length) {
const T x = input[idx];
const T in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
Expand All @@ -60,153 +59,118 @@ __global__ void FastGeluKernel(const T a, const T b, const T c, const T oneT, co
}
}

template <unsigned TPB>
__global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, const half2 one2, const half2 two2,
int input_length, int bias_length, const half2* input, const half2* bias,
half2* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;

if (idx < input_length) {
const half2 x = input[idx];
const half2 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);

// const half2 cdf = a + a * _Tanh(in * (c * in * in + b));
const half2 u = two2 * in * (c * in * in + b);
const half2 emu = h2exp(-u);
const half2 cdf = a + a * (two2/(one2 + emu) - one2);

output[idx] = in * cdf;
}
}

template <unsigned TPB>
__global__ void FastGeluKernel4Bias(const half2 a, const half2 b, const half2 c, const half2 one2, const half2 two2,
int input_length, int bias_length, const float2* input, const float2* bias,
float2* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;

template <typename T, unsigned TPB, int ILP>
__global__ void FastGeluKernelVec(int input_length, int bias_length, const T* input, const T* bias,
T* output) {
using VecT = aligned_vector<T, ILP>;
const T a = T(0.5f);
const T b = T(0.7978845608028654f);
const T c = T(0.035677408136300125f);
const T oneT = T(1.0f);
const T twoT = T(2.0f);

const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP;
if (idx < input_length) {
float2 input_vec = input[idx];
float2 bias_vec = bias[idx % bias_length];
float2 output_vec = output[idx];

half2* input_half = reinterpret_cast<half2*>(&input_vec);
half2* bias_half = reinterpret_cast<half2*>(&bias_vec);
half2* output_half = reinterpret_cast<half2*>(&output_vec);

half2 lo_data = input_half[0];
half2 hi_data = input_half[1];
half2 lo_bias = bias_half[0];
half2 hi_bias = bias_half[1];

lo_data += lo_bias;
hi_data += hi_bias;

const half2 lo_u = two2 * lo_data * (c * lo_data * lo_data + b);
const half2 hi_u = two2 * hi_data * (c * hi_data * hi_data + b);
const half2 lo_emu = h2exp(-lo_u);
const half2 hi_emu = h2exp(-hi_u);
const half2 lo_cdf = a + a * (two2/(one2 + lo_emu) - one2);
const half2 hi_cdf = a + a * (two2/(one2 + hi_emu) - one2);

output_half[0] = lo_data * lo_cdf;
output_half[1] = hi_data * hi_cdf;

output[idx] = output_vec;
}
}

template <unsigned TPB>
__global__ void FastGeluKernel4(const half2 a, const half2 b, const half2 c, const half2 one2, const half2 two2,
int input_length, const float2* input, float2* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;

if (idx < input_length) {
float2 input_vec = input[idx];
float2 output_vec = output[idx];

half2* input_half = reinterpret_cast<half2*>(&input_vec);
half2* output_half = reinterpret_cast<half2*>(&output_vec);

half2 lo_data = input_half[0];
half2 hi_data = input_half[1];

const half2 lo_u = two2 * lo_data * (c * lo_data * lo_data + b);
const half2 hi_u = two2 * hi_data * (c * hi_data * hi_data + b);
const half2 lo_emu = h2exp(-lo_u);
const half2 hi_emu = h2exp(-hi_u);
const half2 lo_cdf = a + a * (two2/(one2 + lo_emu) - one2);
const half2 hi_cdf = a + a * (two2/(one2 + hi_emu) - one2);

output_half[0] = lo_data * lo_cdf;
output_half[1] = hi_data * hi_cdf;
using VecT = aligned_vector<T, ILP>;
T input_v[ILP];
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
T output_v[ILP];
VecT* output_val = reinterpret_cast<VecT*>(&output_v);
T bias_v[ILP];
if (bias != nullptr) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[idx % bias_length]);
}

output[idx] = output_vec;
#pragma unroll
for (int i = 0; i < ILP; i++) {
const T x = (bias == nullptr) ? input_v[i] : input_v[i] + bias_v[i];
const T u = twoT * x * (c * x * x + b);
const T emu = __expf(-u);
const T cdf = a + a * (twoT/(oneT + emu) - oneT);
output_v[i] = x * cdf;
}
*(reinterpret_cast<VecT*>(&output[idx])) = *reinterpret_cast<VecT*>(&output_v[0]);
}
}

template <>
bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int input_length, int bias_length,
bool LaunchFastGeluKernel(hipStream_t stream, int input_length, int bias_length,
const float* input, const float* bias, float* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<float, blockSize>), dim3(gridSize), dim3(blockSize), 0,
stream, A, B, C, one, two, input_length, bias_length, input, bias, output);

constexpr int block_size = 256;
const int grid_size = (input_length + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<float, block_size>), dim3(grid_size), dim3(block_size), 0,
stream, input_length, bias_length, input, bias, output);
return HIP_CALL(hipPeekAtLastError());
}

template <>
bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int input_length, int bias_length,
bool LaunchFastGeluKernel(hipStream_t stream, int input_length, int bias_length,
const half* input, const half* bias, half* output, bool use_half2) {
constexpr int blockSize = 256;
if (use_half2 && prop.major >= 7 && (0 == (bias_length % 4) || 0 == (bias_length & 1))) {
const half2 A2 = __float2half2_rn(A);
const half2 B2 = __float2half2_rn(B);
const half2 C2 = __float2half2_rn(C);
const half2 one2 = __float2half2_rn(one);
const half2 two2 = __float2half2_rn(two);
if (0 == (bias_length % 4)) {
const int n = input_length / 4;
const int gridSize = (n + blockSize - 1) / blockSize;
const float2* input4 = reinterpret_cast<const float2*>(input);
const float2* bias4 = reinterpret_cast<const float2*>(bias);
float2* output4 = reinterpret_cast<float2*>(output);
if (bias == nullptr)
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel4<blockSize>), dim3(gridSize), dim3(blockSize), 0,
stream, A2, B2, C2, one2, two2, n, input4, output4);
else
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel4Bias<blockSize>), dim3(gridSize), dim3(blockSize), 0,
stream, A2, B2, C2, one2, two2, n, bias_length / 4, input4, bias4, output4);
} else {
const int n = input_length / 2;
const int gridSize = (n + blockSize - 1) / blockSize;
const half2* input2 = reinterpret_cast<const half2*>(input);
const half2* bias2 = reinterpret_cast<const half2*>(bias);
half2* output2 = reinterpret_cast<half2*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel2<blockSize>), dim3(gridSize), dim3(blockSize), 0,
stream, A2, B2, C2, one2, two2, n, bias_length / 2, input2, bias2, output2);
}
constexpr int block_size = 256;
if (use_half2) {
if (bias != nullptr) {
if (0 == (bias_length % 8) && (input_length >= 3145728)) { // 3145728=8*128*3072
const int grid_size = (input_length / 8 + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 8>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
} else if (0 == (bias_length % 4)) {
const int grid_size = (input_length / 4 + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 4>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
} else if (0 == (bias_length % 2)) {
const int grid_size = (input_length / 2 + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 2>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
} else {
const int grid_size = (input_length + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<half, block_size>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
}
} else {
if (0 == (input_length % 8) && (input_length >= 3145728)) { // 3145728=8*128*3072
const int grid_size = (input_length / 8 + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 8>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
} else if (0 == (input_length % 4)) {
const int grid_size = (input_length / 4 + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 4>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
} else if (0 == (input_length % 2)) {
const int grid_size = (input_length / 2 + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 2>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
} else {
const int grid_size = (input_length + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<half, block_size>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
}
}
} else {
const int gridSize = (input_length + blockSize - 1) / blockSize;
const half oneT = half(one);
const half twoT = half(two);
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<half, blockSize>), dim3(gridSize), dim3(blockSize), 0,
stream, A, B, C, oneT, twoT, input_length, bias_length, input, bias, output);
const int grid_size = (input_length + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<half, block_size>), dim3(grid_size),
dim3(block_size), 0, stream, input_length, bias_length,
input, bias, output);
}

return HIP_CALL(hipPeekAtLastError());
}

template <>
bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int input_length, int bias_length,
bool LaunchFastGeluKernel(hipStream_t stream, int input_length, int bias_length,
const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
const BFloat16 oneT = BFloat16(one);
const BFloat16 twoT = BFloat16(two);
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<BFloat16, blockSize>), dim3(gridSize), dim3(blockSize), 0,
stream, A, B, C, oneT, twoT, input_length, bias_length, input, bias, output);
constexpr int block_size = 256;
const int grid_size = (input_length + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<BFloat16, block_size>), dim3(grid_size), dim3(block_size), 0,
stream, input_length, bias_length, input, bias, output);
return HIP_CALL(hipPeekAtLastError());
}

Expand Down
Loading

0 comments on commit f4ba199

Please sign in to comment.