From f4ba199baddbf176369a8fba7ca53a328a361a6c Mon Sep 17 00:00:00 2001
From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com>
Date: Fri, 24 Jun 2022 12:46:17 -0700
Subject: [PATCH] Optimize FastGelu with float2 and float4 vectorized kernels on ROCm (#11491)

* Using vectorized loads (float2) for fp16 to improve performance
* Fix a few warnings from cpplint
* Fix a few warnings from cpplint
* Use __float2half2_rn and fix some cpplint warnings
* Move some computations to LaunchFastGeluKernel
* Fix some Lint C++ warning
* Using vectorized loads (float4) for fp16 to improve performance
* Switch whether to optimize FastGelu with float4 vectorization
* Switch to float4 memory access based on input_length in FastGelu
* Comment how to set the threshold of float2 and float4 vectorized kernels
* Add FastGelu fp16 unit tests for bias_length = 2 and 8
* Make vectorized kernels generic with aligned_vector
* Unify the vectorized kernels with/without bias
* Refactor the code to suppress cpplint warnings
* Solve formatting issues
* Remove cudaDeviceProp from FastGeluKernel and LaunchFastGeluKernel
* Move fast_gelu_impl.h to rocm/bert
* Fix some Lint C++ warnings and code alignment
---
 .../contrib_ops/rocm/bert/fast_gelu.cc        |  73 ++++++
 onnxruntime/contrib_ops/rocm/bert/fast_gelu.h |  26 ++
 .../contrib_ops/rocm/bert/fast_gelu_impl.cu   | 246 ++++++++----------
 .../contrib_ops/rocm/bert/fast_gelu_impl.h    |  21 ++
 .../test/contrib_ops/fastgelu_op_test.cc      |  96 ++++++-
 tools/ci_build/amd_hipify.py                  |   3 +
 6 files changed, 321 insertions(+), 144 deletions(-)
 create mode 100644 onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
 create mode 100644 onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
 create mode 100644 onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h

diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
new file mode 100644
index 0000000000000..c89c7f6f41628
--- /dev/null
+++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
@@ -0,0 +1,73 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Modifications: Remove GetDeviceProp in LaunchFastGeluKernel.
+// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/rocm/rocm_common.h"
+#include "core/providers/rocm/miopen_common.h"
+#include "contrib_ops/rocm/bert/fast_gelu.h"
+#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
+#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
+#include "contrib_ops/rocm/bert/transformer_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace rocm {
+
+#define REGISTER_KERNEL_TYPED(T)                                  \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
+      FastGelu,                                                   \
+      kMSDomain,                                                  \
+      1,                                                          \
+      T,                                                          \
+      kRocmExecutionProvider,                                     \
+      (*KernelDefBuilder::Create())                               \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      FastGelu<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
+
+using namespace ONNX_NAMESPACE;
+
+template <typename T>
+FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
+  const TransformerOptions* options = TransformerOptions::GetInstance();
+  use_half2_ = !options->DisableHalf2();
+}
+
+template <typename T>
+Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
+  ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));
+
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* bias = context->Input<Tensor>(1);
+  Tensor* output = context->Output(0, input->Shape());
+
+  int64_t input_length = input->Shape().Size();
+  if (input_length == 0) {
+    return Status::OK();
+  }
+  int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
+  typedef typename ToHipType<T>::MappedType HipT;
+
+  if (!LaunchFastGeluKernel<HipT>(Stream(),
+                                  static_cast<int>(input_length),
+                                  static_cast<int>(bias_length),
+                                  reinterpret_cast<const HipT*>(input->template Data<T>()),
+                                  (nullptr != bias) ? reinterpret_cast<const HipT*>(bias->template Data<T>()) : nullptr,
+                                  reinterpret_cast<HipT*>(output->template MutableData<T>()),
+                                  use_half2_)) {
+    HIP_CALL(hipGetLastError());
+    return Status(common::ONNXRUNTIME, common::FAIL);
+  }
+
+  return Status::OK();
+}
+
+}  // namespace rocm
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
new file mode 100644
index 0000000000000..3da9c126490db
--- /dev/null
+++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/rocm/rocm_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace rocm {
+
+using namespace onnxruntime::rocm;
+
+template <typename T>
+class FastGelu final : public RocmKernel {
+ public:
+  FastGelu(const OpKernelInfo& op_kernel_info);
+  Status ComputeInternal(OpKernelContext* ctx) const override;
+
+ private:
+  bool use_half2_;
+};
+
+}  // namespace rocm
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu
index b973073baee8f..a8710d9869199 100644
--- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu
+++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu
@@ -21,6 +21,11 @@ limitations under the License.
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+// Modifications: Add FastGeluKernelVec to leverage vectorized load/write
+// and modify FastGeluKernel to get better performance.
+// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+// Licensed under the MIT License.
+
 #include "core/providers/rocm/rocm_common.h"
 #include "core/providers/rocm/cu_inc/common.cuh"
 #include "core/providers/rocm/shared_inc/rocm_call.h"
@@ -32,21 +37,15 @@ namespace onnxruntime {
 namespace contrib {
 namespace rocm {
 
-// constants for approximating the normal cdf
-constexpr float A = 0.5;
-
-constexpr float B = 0.7978845608028654;  // sqrt(2.0/M_PI)
-
-constexpr float C = 0.035677408136300125;  // 0.044715 * sqrt(2.0/M_PI)
-
-constexpr float one = 1.0;
-constexpr float two = 2.0;
-
 template <typename T, unsigned TPB>
-__global__ void FastGeluKernel(const T a, const T b, const T c, const T oneT, const T twoT,
-                               int input_length, int bias_length, const T* input, const T* bias, T* output) {
+__global__ void FastGeluKernel(int input_length, int bias_length, const T* input, const T* bias, T* output) {
   const int idx = blockIdx.x * TPB + threadIdx.x;
-
+  // constants for approximating the normal cdf
+  const T a = T(0.5f);
+  const T b = T(0.7978845608028654f);    // sqrt(2.0/M_PI)
+  const T c = T(0.035677408136300125f);  // 0.044715 * sqrt(2.0/M_PI)
+  const T oneT = T(1.0f);
+  const T twoT = T(2.0f);
   if (idx < input_length) {
     const T x = input[idx];
     const T in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
@@ -60,153 +59,118 @@ __global__ void FastGeluKernel(const T a, const T b, const T c, const T oneT, co
   }
 }
 
-template <unsigned TPB>
-__global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, const half2 one2, const half2 two2,
-                                int input_length, int bias_length, const half2* input, const half2* bias,
-                                half2* output) {
-  const int idx = blockIdx.x * TPB + threadIdx.x;
-
-  if (idx < input_length) {
-    const half2 x = input[idx];
-    const half2 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
-
-    // const half2 cdf = a + a * _Tanh(in * (c * in * in + b));
-    const half2 u = two2 * in * (c * in * in + b);
-    const half2 emu = h2exp(-u);
-    const half2 cdf = a + a * (two2/(one2 + emu) - one2);
-
-    output[idx] = in * cdf;
-  }
-}
-
-template <unsigned TPB>
-__global__ void FastGeluKernel4Bias(const half2 a, const half2 b, const half2 c, const half2 one2, const half2 two2,
-                                    int input_length, int bias_length, const float2* input, const float2* bias,
-                                    float2* output) {
-  const int idx = blockIdx.x * TPB + threadIdx.x;
-
+template <typename T, unsigned TPB, int ILP>
+__global__ void FastGeluKernelVec(int input_length, int bias_length, const T* input, const T* bias,
+                                  T* output) {
+  using VecT = aligned_vector<T, ILP>;
+  const T a = T(0.5f);
+  const T b = T(0.7978845608028654f);
+  const T c = T(0.035677408136300125f);
+  const T oneT = T(1.0f);
+  const T twoT = T(2.0f);
+
+  const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP;
   if (idx < input_length) {
-    float2 input_vec = input[idx];
-    float2 bias_vec = bias[idx % bias_length];
-    float2 output_vec = output[idx];
-
-    half2* input_half = reinterpret_cast<half2*>(&input_vec);
-    half2* bias_half = reinterpret_cast<half2*>(&bias_vec);
-    half2* output_half = reinterpret_cast<half2*>(&output_vec);
-
-    half2 lo_data = input_half[0];
-    half2 hi_data = input_half[1];
-    half2 lo_bias = bias_half[0];
-    half2 hi_bias = bias_half[1];
-
-    lo_data += lo_bias;
-    hi_data += hi_bias;
-
-    const half2 lo_u = two2 * lo_data * (c * lo_data * lo_data + b);
-    const half2 hi_u = two2 * hi_data * (c * hi_data * hi_data + b);
-    const half2 lo_emu = h2exp(-lo_u);
-    const half2 hi_emu = h2exp(-hi_u);
-    const half2 lo_cdf = a + a * (two2/(one2 + lo_emu) - one2);
-    const half2 hi_cdf = a + a * (two2/(one2 + hi_emu) - one2);
-
-    output_half[0] = lo_data * lo_cdf;
-    output_half[1] = hi_data * hi_cdf;
-
-    output[idx] = output_vec;
-  }
-}
-
-template <unsigned TPB>
-__global__ void FastGeluKernel4(const half2 a, const half2 b, const half2 c, const half2 one2, const half2 two2,
-                                int input_length, const float2* input, float2* output) {
-  const int idx = blockIdx.x * TPB + threadIdx.x;
-
-  if (idx < input_length) {
-    float2 input_vec = input[idx];
-    float2 output_vec = output[idx];
-
-    half2* input_half = reinterpret_cast<half2*>(&input_vec);
-    half2* output_half = reinterpret_cast<half2*>(&output_vec);
-
-    half2 lo_data = input_half[0];
-    half2 hi_data = input_half[1];
-
-    const half2 lo_u = two2 * lo_data * (c * lo_data * lo_data + b);
-    const half2 hi_u = two2 * hi_data * (c * hi_data * hi_data + b);
-    const half2 lo_emu = h2exp(-lo_u);
-    const half2 hi_emu = h2exp(-hi_u);
-    const half2 lo_cdf = a + a * (two2/(one2 + lo_emu) - one2);
-    const half2 hi_cdf = a + a * (two2/(one2 + hi_emu) - one2);
-
-    output_half[0] = lo_data * lo_cdf;
-    output_half[1] = hi_data * hi_cdf;
+    using VecT = aligned_vector<T, ILP>;
+    T input_v[ILP];
+    VecT* input_val = reinterpret_cast<VecT*>(&input_v);
+    *input_val = *reinterpret_cast<const VecT*>(&input[idx]);
+    T output_v[ILP];
+    VecT* output_val = reinterpret_cast<VecT*>(&output_v);
+    T bias_v[ILP];
+    if (bias != nullptr) {
+      VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
+      *bias_val = *reinterpret_cast<const VecT*>(&bias[idx % bias_length]);
+    }
 
-    output[idx] = output_vec;
+    #pragma unroll
+    for (int i = 0; i < ILP; i++) {
+      const T x = (bias == nullptr) ? input_v[i] : input_v[i] + bias_v[i];
+      const T u = twoT * x * (c * x * x + b);
+      const T emu = __expf(-u);
+      const T cdf = a + a * (twoT/(oneT + emu) - oneT);
+      output_v[i] = x * cdf;
+    }
+    *(reinterpret_cast<VecT*>(&output[idx])) = *reinterpret_cast<VecT*>(&output_v[0]);
   }
 }
 
 template <>
-bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int input_length, int bias_length,
+bool LaunchFastGeluKernel(hipStream_t stream, int input_length, int bias_length,
                           const float* input, const float* bias, float* output, bool /*use_half2*/) {
-  constexpr int blockSize = 256;
-  const int gridSize = (input_length + blockSize - 1) / blockSize;
-  hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<float, blockSize>), dim3(gridSize), dim3(blockSize), 0,
-                     stream, A, B, C, one, two, input_length, bias_length, input, bias, output);
-
+  constexpr int block_size = 256;
+  const int grid_size = (input_length + block_size - 1) / block_size;
+  hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<float, block_size>), dim3(grid_size), dim3(block_size), 0,
+                     stream, input_length, bias_length, input, bias, output);
   return HIP_CALL(hipPeekAtLastError());
 }
 
 template <>
-bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int input_length, int bias_length,
+bool LaunchFastGeluKernel(hipStream_t stream, int input_length, int bias_length,
                           const half* input, const half* bias, half* output, bool use_half2) {
-  constexpr int blockSize = 256;
-  if (use_half2 && prop.major >= 7 && (0 == (bias_length % 4) || 0 == (bias_length & 1))) {
-    const half2 A2 = __float2half2_rn(A);
-    const half2 B2 = __float2half2_rn(B);
-    const half2 C2 = __float2half2_rn(C);
-    const half2 one2 = __float2half2_rn(one);
-    const half2 two2 = __float2half2_rn(two);
-    if (0 == (bias_length % 4)) {
-      const int n = input_length / 4;
-      const int gridSize = (n + blockSize - 1) / blockSize;
-      const float2* input4 = reinterpret_cast<const float2*>(input);
-      const float2* bias4 = reinterpret_cast<const float2*>(bias);
-      float2* output4 = reinterpret_cast<float2*>(output);
-      if (bias == nullptr)
-        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel4<blockSize>), dim3(gridSize), dim3(blockSize), 0,
-                           stream, A2, B2, C2, one2, two2, n, input4, output4);
-      else
-        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel4Bias<blockSize>), dim3(gridSize), dim3(blockSize), 0,
-                           stream, A2, B2, C2, one2, two2, n, bias_length / 4, input4, bias4, output4);
-    } else {
-      const int n = input_length / 2;
-      const int gridSize = (n + blockSize - 1) / blockSize;
-      const half2* input2 = reinterpret_cast<const half2*>(input);
-      const half2* bias2 = reinterpret_cast<const half2*>(bias);
-      half2* output2 = reinterpret_cast<half2*>(output);
-      hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel2<blockSize>), dim3(gridSize), dim3(blockSize), 0,
-                         stream, A2, B2, C2, one2, two2, n, bias_length / 2, input2, bias2, output2);
-    }
+  constexpr int block_size = 256;
+  if (use_half2) {
+    if (bias != nullptr) {
+      if (0 == (bias_length % 8) && (input_length >= 3145728)) {  // 3145728=8*128*3072
+        const int grid_size = (input_length / 8 + block_size - 1) / block_size;
+        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 8>), dim3(grid_size),
+                           dim3(block_size), 0, stream, input_length, bias_length,
+                           input, bias, output);
+      } else if (0 == (bias_length % 4)) {
+        const int grid_size = (input_length / 4 + block_size - 1) / block_size;
+        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 4>), dim3(grid_size),
+                           dim3(block_size), 0, stream, input_length, bias_length,
+                           input, bias, output);
+      } else if (0 == (bias_length % 2)) {
+        const int grid_size = (input_length / 2 + block_size - 1) / block_size;
+        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 2>), dim3(grid_size),
+                           dim3(block_size), 0, stream, input_length, bias_length,
+                           input, bias, output);
+      } else {
+        const int grid_size = (input_length + block_size - 1) / block_size;
+        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<half, block_size>), dim3(grid_size),
+                           dim3(block_size), 0, stream, input_length, bias_length,
+                           input, bias, output);
+      }
+    } else {
+      if (0 == (input_length % 8) && (input_length >= 3145728)) {  // 3145728=8*128*3072
+        const int grid_size = (input_length / 8 + block_size - 1) / block_size;
+        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 8>), dim3(grid_size),
+                           dim3(block_size), 0, stream, input_length, bias_length,
+                           input, bias, output);
+      } else if (0 == (input_length % 4)) {
+        const int grid_size = (input_length / 4 + block_size - 1) / block_size;
+        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 4>), dim3(grid_size),
+                           dim3(block_size), 0, stream, input_length, bias_length,
+                           input, bias, output);
+      } else if (0 == (input_length % 2)) {
+        const int grid_size = (input_length / 2 + block_size - 1) / block_size;
+        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernelVec<half, block_size, 2>), dim3(grid_size),
+                           dim3(block_size), 0, stream, input_length, bias_length,
+                           input, bias, output);
+      } else {
+        const int grid_size = (input_length + block_size - 1) / block_size;
+        hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<half, block_size>), dim3(grid_size),
+                           dim3(block_size), 0, stream, input_length, bias_length,
+                           input, bias, output);
+      }
+    }
   } else {
-    const int gridSize = (input_length + blockSize - 1) / blockSize;
-    const half oneT = half(one);
-    const half twoT = half(two);
-    hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<half, blockSize>), dim3(gridSize), dim3(blockSize), 0,
-                       stream, A, B, C, oneT, twoT, input_length, bias_length, input, bias, output);
+    const int grid_size = (input_length + block_size - 1) / block_size;
+    hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<half, block_size>), dim3(grid_size),
+                       dim3(block_size), 0, stream, input_length, bias_length,
+                       input, bias, output);
   }
-
   return HIP_CALL(hipPeekAtLastError());
 }
 
 template <>
-bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int input_length, int bias_length,
+bool LaunchFastGeluKernel(hipStream_t stream, int input_length, int bias_length,
                           const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) {
-  constexpr int blockSize = 256;
-  const int gridSize = (input_length + blockSize - 1) / blockSize;
-  const BFloat16 oneT = BFloat16(one);
-  const BFloat16 twoT = BFloat16(two);
-  hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<BFloat16, blockSize>), dim3(gridSize), dim3(blockSize), 0,
-                     stream, A, B, C, oneT, twoT, input_length, bias_length, input, bias, output);
+  constexpr int block_size = 256;
+  const int grid_size = (input_length + block_size - 1) / block_size;
+  hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel<BFloat16, block_size>), dim3(grid_size), dim3(block_size), 0,
+                     stream, input_length, bias_length, input, bias, output);
   return HIP_CALL(hipPeekAtLastError());
 }
 
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h
new file mode 100644
index 0000000000000..c626a30418ee4
--- /dev/null
+++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Modifications: Remove cudaDeviceProp in LaunchFastGeluKernel.
+// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+
+namespace onnxruntime {
+namespace contrib {
+namespace rocm {
+
+template <typename T>
+bool LaunchFastGeluKernel(hipStream_t stream, int input_length, int bias_length,
+                          const T* input, const T* bias, T* output, bool use_half2);
+
+}  // namespace rocm
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc
index 302226e1071b9..1210b6e791db6 100644
--- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc
+++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -167,7 +167,52 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) {
 
 // CUDA and ROCm only for Float16 and BFloat16 type.
 #if defined(USE_CUDA) || defined(USE_ROCM)
-TEST(FastGeluTest, FastGeluWithBiasFloat16) {
+TEST(FastGeluTest, FastGeluWithBiasFloat16_2) {
+  int batch_size = 1;
+  int sequence_length = 2;
+  int hidden_size = 2;
+
+  std::vector<float> input_data = {
+      0.8f, -0.5f,
+      0.5f, 0.2f};
+
+  std::vector<float> bias_data = {
+      -0.5f, 0.6f};
+
+  std::vector<float> output_data = {
+      0.1851806640625f, 0.054046630859375f,
+      0, 0.63037109375f};
+
+  std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+  std::vector<int64_t> bias_dims = {hidden_size};
+  std::vector<int64_t> output_dims = input_dims;
+
+  RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, true, true);
+}
+
+TEST(FastGeluTest, FastGeluWithoutBiasFloat16_2) {
+  int batch_size = 1;
+  int sequence_length = 2;
+  int hidden_size = 2;
+
+  std::vector<float> input_data = {
+      0.8f, -0.5f,
+      0.5f, 0.2f};
+
+  std::vector<float> bias_data = {};
+
+  std::vector<float> output_data = {
+      0.63037109375f, -0.154296875f,
+      0.345703125f, 0.11578369140625f};
+
+  std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+  std::vector<int64_t> bias_dims = {};
+  std::vector<int64_t> output_dims = input_dims;
+
+  RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
+}
+
+TEST(FastGeluTest, FastGeluWithBiasFloat16_4) {
   int batch_size = 1;
   int sequence_length = 2;
   int hidden_size = 4;
@@ -190,7 +235,7 @@ TEST(FastGeluTest, FastGeluWithBiasFloat16) {
   RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, true, true);
 }
 
-TEST(FastGeluTest, FastGeluWithoutBiasFloat16) {
+TEST(FastGeluTest, FastGeluWithoutBiasFloat16_4) {
   int batch_size = 1;
   int sequence_length = 2;
   int hidden_size = 4;
@@ -212,6 +257,51 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16) {
   RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
 }
 
+TEST(FastGeluTest, FastGeluWithBiasFloat16_8) {
+  int batch_size = 1;
+  int sequence_length = 2;
+  int hidden_size = 8;
+
+  std::vector<float> input_data = {
+      0.8f, -0.5f, 0.0f, 1.f, 1.3f, 2.1f, -0.2f, 1.1f,
+      0.5f, 0.2f, 0.3f, -0.6f, 3.1f, 2.2f, -1.1f, 0.0f};
+
+  std::vector<float> bias_data = {
+      -0.5f, 0.6f, 1.2f, 2.1f, 1.3f, -1.0f, 0.0f, 3.1f};
+
+  std::vector<float> output_data = {
+      0.18537094f, 0.053982764f, 1.061703f, 3.0973732f, 2.5883462f, 0.95058095f, -0.084148578f, 4.1999736f,
+      0.0f, 0.63043171f, 1.3995714f, 1.3995714f, 4.3999906f, 1.061703f, -0.14941895f, 3.0973732f};
+
+  std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+  std::vector<int64_t> bias_dims = {hidden_size};
+  std::vector<int64_t> output_dims = input_dims;
+
+  RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, true, true);
+}
+
+TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) {
+  int batch_size = 1;
+  int sequence_length = 2;
+  int hidden_size = 8;
+
+  std::vector<float> input_data = {
+      0.8f, -0.5f, 0.0f, 1.f, 1.3f, 2.1f, -0.2f, 1.1f,
+      0.5f, 0.2f, 0.3f, -0.6f, 3.1f, 2.2f, -1.1f, 0.0f};
+
+  std::vector<float> bias_data = {};
+
+  std::vector<float> output_data = {
+      0.63043171f, -0.15428598f, 0.0f, 0.84119201f, 1.173929f, 2.062669f, -0.084148578f, 0.95058107f,
+      0.345714f, 0.11585142f, 0.18537094f, -0.1645848f, 3.0973732f, 2.1696784f, -0.14941895f, 0.0f};
+
+  std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+  std::vector<int64_t> bias_dims = {hidden_size};
+  std::vector<int64_t> output_dims = input_dims;
+
+  RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
+}
+
 TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
 #ifdef USE_CUDA
   int min_cuda_architecture = 530;
@@ -254,7 +344,7 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
   execution_providers.push_back(DefaultCudaExecutionProvider());
 #elif USE_ROCM
   execution_providers.push_back(DefaultRocmExecutionProvider());
-#endif 
+#endif
   tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
 }
 #endif
diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py
index 6ede643f58b09..bbf0dcc2424d1 100644
--- a/tools/ci_build/amd_hipify.py
+++ b/tools/ci_build/amd_hipify.py
@@ -26,6 +26,9 @@
     "bert/embed_layer_norm_impl.cu",
     "bert/embed_layer_norm_impl.h",
     "bert/fast_gelu_impl.cu",
+    "bert/fast_gelu_impl.h",
+    "bert/fast_gelu.cc",
+    "bert/fast_gelu.h",
    # 'bert/layer_norm.cuh',
     "bert/longformer_attention.cc",
     "bert/longformer_attention.h",
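
For orientation only, outside the patch itself: a minimal, self-contained sketch of the aligned_vector load/store idiom that FastGeluKernelVec relies on. The struct below is a simplified stand-in for the aligned_vector helper in onnxruntime's shared GPU headers, and ScaleKernelVec is a hypothetical kernel invented for this illustration; neither is part of the change set above.

// Illustration only: the vectorized access pattern used by FastGeluKernelVec,
// shown with a simplified stand-in for ORT's aligned_vector helper.
#include <hip/hip_runtime.h>

// A wrapper whose alignment lets the compiler emit one wide (e.g. 16-byte)
// load/store for ILP consecutive elements of type T.
template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) aligned_vector {
  T val[ILP];
};

// Hypothetical kernel isolating the idiom: each thread handles ILP elements
// through a single vector load and a single vector store. The caller must
// guarantee n % ILP == 0 and launch (n / ILP + TPB - 1) / TPB blocks of TPB
// threads, mirroring the divisibility-based dispatch in LaunchFastGeluKernel.
template <typename T, int TPB, int ILP>
__global__ void ScaleKernelVec(int n, const T* input, T* output, T scale) {
  using VecT = aligned_vector<T, ILP>;
  const int idx = (blockIdx.x * TPB + threadIdx.x) * ILP;
  if (idx < n) {
    T in_v[ILP];
    // One wide load of ILP elements.
    *reinterpret_cast<VecT*>(in_v) = *reinterpret_cast<const VecT*>(input + idx);
    T out_v[ILP];
#pragma unroll
    for (int i = 0; i < ILP; ++i) {
      out_v[i] = in_v[i] * scale;  // per-element work stays scalar
    }
    // One wide store of ILP elements.
    *reinterpret_cast<VecT*>(output + idx) = *reinterpret_cast<VecT*>(out_v);
  }
}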