diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp
index 821a32d6f..9e8e70d8d 100644
--- a/nntrainer/cl_context.cpp
+++ b/nntrainer/cl_context.cpp
@@ -15,6 +15,7 @@
  */
 
 #include
+#include <attention_kernel_strings.h>
 #include
 #include
 #include
@@ -153,6 +154,21 @@ void ClContext::initBlasClKernels() {
   blas_kernels_initialized = true;
 }
 
+void ClContext::initAttentionClKernels() {
+  if (attention_kernels_initialized) {
+    ml_logi("ClContext: Default attention kernels already registered and "
+            "initialized");
+    return;
+  }
+
+  registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl");
+
+#ifdef ENABLE_FP16
+  registerClKernel(rotary_emb_cl_kernel_fp16_, "rotary_emb_cl_fp16");
+#endif
+  attention_kernels_initialized = true;
+}
+
 const ClContext::SharedPtrClKernel &
 ClContext::registerClKernel(std::string kernel_string,
                             std::string kernel_name) {
diff --git a/nntrainer/cl_context.h b/nntrainer/cl_context.h
index ded338bc0..04cc8da54 100644
--- a/nntrainer/cl_context.h
+++ b/nntrainer/cl_context.h
@@ -211,6 +211,11 @@ class ClContext {
    */
   void initBlasClKernels();
 
+  /**
+   * @brief Initialize and register all attention OpenCl kernels
+   */
+  void initAttentionClKernels();
+
   /**
    * @brief destructor to release opencl commandQueue
    */
@@ -229,6 +234,9 @@ class ClContext {
   // flag to check default blas kernels registered or not
   bool blas_kernels_initialized = false;
 
+  // flag to check default attention kernels registered or not
+  bool attention_kernels_initialized = false;
+
   FactoryMap factory_map;
 
   template struct isSupportedHelper;
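With this change the rotary-embedding kernels are registered once on the shared ClContext instead of being created per call site. A minimal caller sketch (hypothetical call site; it uses only the APIs visible in this patch, and the registration is idempotent thanks to the attention_kernels_initialized flag added above):

    #include <cl_context.h>

    void ensure_attention_kernels() {
      // Global() is the shared ClContext singleton, used the same way in the
      // updated unit tests at the end of this patch.
      auto &cl_ctx = nntrainer::ClContext::Global();
      cl_ctx.initAttentionClKernels(); // registers rotary_emb_cl (+ fp16 variant)
    }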
diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp
index 85c3331ed..658e2a3d9 100644
--- a/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp
+++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.cpp
@@ -59,12 +59,11 @@ void precompute_freqs(unsigned int dim, unsigned int seq_len,
  * @param[in] dim hidden dim size
  * @param[in] from sequence order
  * @param[in] max_timestep maximum timestep
- * @param[in] context layer context to get the resource manager and queue id
  *
  * @todo Calling precompute_freqs in finalize to reduce code redundancy.
  */
 void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
-                         unsigned int max_timestep, RunLayerContext &context) {
+                         unsigned int max_timestep) {
   nntrainer::Tensor out(in.getDim());
   float value = 0.0f;
   float transformed_value = 0.0f;
@@ -111,7 +110,7 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
 
     rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
                   input_batch_size, input_channels, input_height, input_width,
-                  dim, from, max_timestep, in_size, out_size, context);
+                  dim, from, max_timestep, in_size, out_size);
 
   } else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
@@ -123,7 +122,7 @@
 
     rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
                   input_batch_size, input_channels, input_height, input_width,
-                  dim, from, max_timestep, in_size, out_size, context);
+                  dim, from, max_timestep, in_size, out_size);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
diff --git a/nntrainer/tensor/cl_operations/attention_kernel_interface.h b/nntrainer/tensor/cl_operations/attention_kernel_interface.h
index b287cb0a4..fe9c0f8b0 100644
--- a/nntrainer/tensor/cl_operations/attention_kernel_interface.h
+++ b/nntrainer/tensor/cl_operations/attention_kernel_interface.h
@@ -14,8 +14,8 @@
 #ifndef __ATTENTION_KERNEL_INTERFACE_H__
 #define __ATTENTION_KERNEL_INTERFACE_H__
 
-#include
 #include
+#include
 
 namespace nntrainer {
 
@@ -25,10 +25,9 @@ namespace nntrainer {
  * @param[in] dim hidden dim size
  * @param[in] from sequence order
  * @param[in] max_timestep maximum timestep
- * @param[in] context layer context to get the resource manager and queue id
  */
 void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
-                         unsigned int max_timestep, RunLayerContext &context);
+                         unsigned int max_timestep);
 
 } // namespace nntrainer
 
 #endif /* __ATTENTION_KERNEL_INTERFACE_H__ */
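After the signature change, apply_rotary_emb_cl needs only the tensor and the rotary parameters; the OpenCL resources are resolved internally. A minimal sketch of a call site (the include paths, parameter values, and tensor provenance are illustrative assumptions, not taken from this patch):

    #include <attention_kernel_interface.h>
    #include <tensor.h>

    void rotate_query(nntrainer::Tensor &query) {
      const unsigned int dim = 60;           // rotary dimension (illustrative)
      const unsigned int from = 0;           // starting sequence position
      const unsigned int max_timestep = 100; // cache length (illustrative)

      // In-place rotary embedding on the GPU; no RunLayerContext is needed anymore.
      nntrainer::apply_rotary_emb_cl(query, dim, from, max_timestep);
    }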
diff --git a/nntrainer/tensor/cl_operations/attention_kernel_strings.h b/nntrainer/tensor/cl_operations/attention_kernel_strings.h
new file mode 100644
index 000000000..d58fd7503
--- /dev/null
+++ b/nntrainer/tensor/cl_operations/attention_kernel_strings.h
@@ -0,0 +1,133 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Yash Singh
+ *
+ * @file   attention_kernel_strings.h
+ * @date   8 October 2024
+ * @brief  All attention OpenCL kernel strings
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Yash Singh
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __ATTENTION_KERNEL_STRINGS_H__
+#define __ATTENTION_KERNEL_STRINGS_H__
+
+#include <string>
+
+namespace nntrainer {
+
+static const std::string rotary_emb_cl_kernel_ = R"(
+
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void rotary_emb_cl(__global float *input,
+                            __global float *output,
+                            __global float *freqs_cos,
+                            __global float *freqs_sin,
+                            __global float *cos_,
+                            __global float *sin_,
+                            unsigned int batch,
+                            unsigned int channel,
+                            unsigned int height,
+                            unsigned int width,
+                            unsigned int dim,
+                            unsigned int half_,
+                            unsigned int max_timestep,
+                            unsigned int from) {
+  __global float *cos_ptr = cos_;
+  __global float *sin_ptr = sin_;
+
+  float value = 0.0f;
+  float transformed_value = 0.0f;
+
+  unsigned int b = get_global_id(0);
+  unsigned int c = get_global_id(1);
+
+  if(b < batch && c < channel){
+    for (unsigned int h = 0; h < height; h++) {
+      if (from + h < max_timestep) {
+        unsigned idx = (from + h)*dim;
+        for(unsigned int i = idx; i < idx + dim; i++){
+          cos_ptr[i - idx] = freqs_cos[i];
+          sin_ptr[i - idx] = freqs_sin[i];
+        }
+      }
+
+      for (unsigned int w = 0; w < width; w = w + dim) {
+        for (unsigned int k = 0; k < dim; k++) {
+          unsigned int span = w + k;
+          value = input[b * channel * height * width + c * height * width + h * width + span];
+          if (k < half_) {
+            transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
+          } else {
+            transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
+          }
+          value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+          output[b * channel * height * width + c * height * width + h * width + span] = value;
+        }
+      }
+    }
+  }
+}
+)";
+
+#ifdef ENABLE_FP16
+static const std::string rotary_emb_cl_kernel_fp16_ = R"(
+
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void rotary_emb_cl_fp16(__global half *input,
+                                 __global half *output,
+                                 __global float *freqs_cos,
+                                 __global float *freqs_sin,
+                                 __global float *cos_,
+                                 __global float *sin_,
+                                 unsigned int batch,
+                                 unsigned int channel,
+                                 unsigned int height,
+                                 unsigned int width,
+                                 unsigned int dim,
+                                 unsigned int half_,
+                                 unsigned int max_timestep,
+                                 unsigned int from) {
+  __global float *cos_ptr = cos_;
+  __global float *sin_ptr = sin_;
+
+  float value = 0.0f;
+  float transformed_value = 0.0f;
+
+  unsigned int b = get_global_id(0);
+  unsigned int c = get_global_id(1);
+
+  if(b < batch && c < channel){
+    for (unsigned int h = 0; h < height; h++) {
+      if (from + h < max_timestep) {
+        unsigned idx = (from + h)*dim;
+        for(unsigned int i = idx; i < idx + dim; i++){
+          cos_ptr[i - idx] = freqs_cos[i];
+          sin_ptr[i - idx] = freqs_sin[i];
+        }
+      }
+
+      for (unsigned int w = 0; w < width; w = w + dim) {
+        for (unsigned int k = 0; k < dim; k++) {
+          unsigned int span = w + k;
+          value = (float)input[b * channel * height * width + c * height * width + h * width + span];
+          if (k < half_) {
+            transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
+          } else {
+            transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
+          }
+          value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
+          output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
+        }
+      }
+    }
+  }
+}
+)";
+
+#endif
+} // namespace nntrainer
+#endif /* __ATTENTION_KERNEL_STRINGS_H__ */
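Because the kernel sources now live as plain strings in a header, they can be syntax-checked outside nntrainer with nothing but the standard OpenCL host API. A standalone validation sketch (not part of this patch; it assumes an OpenCL GPU device is available and only checks that a kernel string compiles):

    #include <CL/cl.h>
    #include <cstdio>
    #include <string>

    // Returns true if the given OpenCL C source builds on the first GPU device.
    bool builds_ok(const std::string &src) {
      cl_platform_id platform;
      cl_device_id device;
      if (clGetPlatformIDs(1, &platform, nullptr) != CL_SUCCESS ||
          clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, nullptr) !=
            CL_SUCCESS)
        return false;

      cl_int err = CL_SUCCESS;
      cl_context ctx = clCreateContext(nullptr, 1, &device, nullptr, nullptr, &err);
      if (err != CL_SUCCESS)
        return false;

      const char *text = src.c_str();
      cl_program prog = clCreateProgramWithSource(ctx, 1, &text, nullptr, &err);
      err = clBuildProgram(prog, 1, &device, nullptr, nullptr, nullptr);
      if (err != CL_SUCCESS) {
        char log[4096] = {0};
        clGetProgramBuildInfo(prog, device, CL_PROGRAM_BUILD_LOG, sizeof(log),
                              log, nullptr);
        std::printf("build failed:\n%s\n", log); // OpenCL compiler output
      }
      clReleaseProgram(prog);
      clReleaseContext(ctx);
      return err == CL_SUCCESS;
    }

    // e.g. builds_ok(nntrainer::rotary_emb_cl_kernel_) after including the header above.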
diff --git a/nntrainer/tensor/cl_operations/attention_kernels.cpp b/nntrainer/tensor/cl_operations/attention_kernels.cpp
index 5fd646b7c..388cc0805 100644
--- a/nntrainer/tensor/cl_operations/attention_kernels.cpp
+++ b/nntrainer/tensor/cl_operations/attention_kernels.cpp
@@ -11,66 +11,10 @@
  *
  */
 
+#include <attention_kernel_strings.h>
 #include <attention_kernels.h>
 
 namespace nntrainer {
-std::string rotary_emb_cl_kernel = R"(
-    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
-__kernel void rotary_emb_cl(__global float *input,
-                            __global float *output,
-                            __global float *freqs_cos,
-                            __global float *freqs_sin,
-                            __global float *cos_,
-                            __global float *sin_,
-                            unsigned int batch,
-                            unsigned int channel,
-                            unsigned int height,
-                            unsigned int width,
-                            unsigned int dim,
-                            unsigned int half_,
-                            unsigned int max_timestep,
-                            unsigned int from) {
-  __global float *cos_ptr = cos_;
-  __global float *sin_ptr = sin_;
-
-  float value = 0.0f;
-  float transformed_value = 0.0f;
-
-  unsigned int b = get_global_id(0);
-  unsigned int c = get_global_id(1);
-
-  if(b < batch && c < channel){
-    for (unsigned int h = 0; h < height; h++) {
-      if (from + h < max_timestep) {
-        unsigned idx = (from + h)*dim;
-        for(unsigned int i = idx; i < idx + dim; i++){
-          cos_ptr[i - idx] = freqs_cos[i];
-          sin_ptr[i - idx] = freqs_sin[i];
-        }
-      }
-
-      for (unsigned int w = 0; w < width; w = w + dim) {
-        for (unsigned int k = 0; k < dim; k++) {
-          unsigned int span = w + k;
-          value = input[b * channel * height * width + c * height * width + h * width + span];
-          if (k < half_) {
-            transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
-          } else {
-            transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
-          }
-          value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
-          output[b * channel * height * width + c * height * width + h * width + span] = value;
-        }
-      }
-    }
-  }
-}
-)";
-
-/**
- * @brief defining global kernel objects
- */
-opencl::Kernel kernel_rotary_emb;
 
 void rotary_emb_cl(float *in, float *out,
                    std::vector<std::vector<float>> freqs_cos,
@@ -79,17 +23,16 @@ void rotary_emb_cl(float *in, float *out,
                    unsigned int batch, unsigned int channel,
                    unsigned int height, unsigned int width, unsigned int dim,
                    unsigned int from, unsigned int max_timestep,
-                   unsigned int in_size, unsigned int out_size,
-                   RunLayerContext &context) {
+                   unsigned int in_size, unsigned int out_size) {
   bool result = false;
 
   do {
-    result = context.clCreateKernel(
-      rotary_emb_cl_kernel, context.LayerKernel::ROTARY_EMB, kernel_rotary_emb);
-    if (!result) {
-      printf("Failed to create kernel for rotary_emb_cl\n");
+    ClContext::SharedPtrClKernel kernel_rotaryEmb_ptr =
+      cl_context_ref.registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl");
+    if (!kernel_rotaryEmb_ptr) {
       break;
     }
+
     unsigned int cos_dim = cos_.size();
     unsigned int sin_dim = sin_.size();
     unsigned int freqs_cos_dim = freqs_cos.size();
@@ -103,18 +46,22 @@
       sizeof(float) * freqs_cos_dim * dim; // max_timestep * dim
     size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
 
-    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+    opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
+                          nullptr);
 
-    opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+    opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true,
+                            nullptr);
 
-    opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+    opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true,
+                          nullptr);
 
-    opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+    opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true,
+                          nullptr);
 
-    opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+    opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true,
                                 nullptr);
 
-    opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+    opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true,
                                 nullptr);
 
     std::vector<float> freqs_cos_flat;
@@ -126,126 +73,130 @@ void rotary_emb_cl(float *in, float *out,
       freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
     }
 
-    result = inputA.WriteData(context.command_queue_inst_, in);
+    result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
     if (!result) {
       printf("Failed to write input data\n");
       break;
     }
 
-    result = inOutRes.WriteData(context.command_queue_inst_, out);
+    result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out);
     if (!result) {
       printf("Failed to write output data\n");
      break;
     }
 
-    result = freqs_cosBuf.WriteData(context.command_queue_inst_,
+    result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_,
                                     freqs_cos_flat.data());
     if (!result) {
       printf("Failed to write freqs cos data\n");
       break;
     }
 
-    result = freqs_sinBuf.WriteData(context.command_queue_inst_,
+    result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_,
                                     freqs_sin_flat.data());
     if (!result) {
       printf("Failed to write freqs sin data\n");
       break;
     }
 
-    result = cosBuf.WriteData(context.command_queue_inst_, cos_.data());
+    result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data());
     if (!result) {
       printf("Failed to write cos data\n");
       break;
     }
 
-    result = sinBuf.WriteData(context.command_queue_inst_, sin_.data());
+    result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data());
     if (!result) {
       printf("Failed to write sin data\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    result =
+      kernel_rotaryEmb_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set inputA argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+    result =
+      kernel_rotaryEmb_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
    if (!result) {
       printf("Failed to set inOutRes argument\n");
       break;
     }
 
-    result =
-      kernel_rotary_emb.SetKernelArguments(2, &freqs_cosBuf, sizeof(cl_mem));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(2, &freqs_cosBuf,
+                                                      sizeof(cl_mem));
     if (!result) {
       printf("Failed to set freqs_cosBuf argument\n");
       break;
     }
 
-    result =
-      kernel_rotary_emb.SetKernelArguments(3, &freqs_sinBuf, sizeof(cl_mem));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(3, &freqs_sinBuf,
+                                                      sizeof(cl_mem));
     if (!result) {
       printf("Failed to set freqs_sinBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
+    result =
+      kernel_rotaryEmb_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set cosBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
+    result =
+      kernel_rotaryEmb_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set sinBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(6, &batch, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(6, &batch, sizeof(int));
     if (!result) {
       printf("Failed to set batch argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(7, &channel, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(7, &channel, sizeof(int));
     if (!result) {
       printf("Failed to set channel argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(8, &height, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(8, &height, sizeof(int));
     if (!result) {
       printf("Failed to set height argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(9, &width, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(9, &width, sizeof(int));
     if (!result) {
       printf("Failed to set width argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(10, &dim, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(10, &dim, sizeof(int));
     if (!result) {
       printf("Failed to set dim argument\n");
       break;
     }
 
     unsigned int half_ = dim / 2;
-    result = kernel_rotary_emb.SetKernelArguments(11, &half_, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(11, &half_, sizeof(int));
     if (!result) {
       printf("Failed to set half argument\n");
       break;
     }
 
     result =
-      kernel_rotary_emb.SetKernelArguments(12, &max_timestep, sizeof(int));
+      kernel_rotaryEmb_ptr->SetKernelArguments(12, &max_timestep, sizeof(int));
     if (!result) {
       printf("Failed to set timestamp argument\n");
       break;
     }
 
-    result = kernel_rotary_emb.SetKernelArguments(13, &from, sizeof(int));
+    result = kernel_rotaryEmb_ptr->SetKernelArguments(13, &from, sizeof(int));
     if (!result) {
       printf("Failed to set from argument\n");
       break;
@@ -253,14 +204,14 @@ void rotary_emb_cl(float *in, float *out,
 
     const int work_groups_count[3] = {(int)batch, (int)channel, 1};
     const int work_group_size[3] = {32, 32, 1}; // test-value
-    result = context.command_queue_inst_.DispatchCommand(
-      kernel_rotary_emb, work_groups_count, work_group_size);
+    result = cl_context_ref.command_queue_inst_.DispatchCommand(
+      kernel_rotaryEmb_ptr, work_groups_count, work_group_size);
     if (!result) {
       printf("Failed to dispatch command\n");
       break;
     }
 
-    result = inOutRes.ReadData(context.command_queue_inst_, out);
+    result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out);
     if (!result) {
       printf("Failed to read data\n");
       break;
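The dispatch above keeps the hard-coded {32, 32, 1} local size that the code itself marks as a test value; on devices whose kernel work-group limit is smaller, such a launch can fail. A hedged sketch of how a caller could clamp the local size using only the standard OpenCL host API (this is independent of nntrainer's wrapper classes, and the 32x32 starting point is simply the value used in this patch):

    #include <CL/cl.h>

    // Shrink a desired 2D local size until it fits the kernel/device limit.
    void clamp_local_size(cl_kernel kernel, cl_device_id device, size_t local[3]) {
      size_t max_wg = 1;
      clGetKernelWorkGroupInfo(kernel, device, CL_KERNEL_WORK_GROUP_SIZE,
                               sizeof(max_wg), &max_wg, nullptr);
      size_t x = 32, y = 32; // the test value used by the dispatch above
      while (x * y > max_wg && y > 1)
        y /= 2;
      while (x * y > max_wg && x > 1)
        x /= 2;
      local[0] = x;
      local[1] = y;
      local[2] = 1;
    }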
printf("Failed to set half argument\n"); break; } result = - kernel_rotary_emb.SetKernelArguments(12, &max_timestep, sizeof(int)); + kernel_rotaryEmb_ptr->SetKernelArguments(12, &max_timestep, sizeof(int)); if (!result) { printf("Failed to set timestamp argument\n"); break; } - result = kernel_rotary_emb.SetKernelArguments(13, &from, sizeof(int)); + result = kernel_rotaryEmb_ptr->SetKernelArguments(13, &from, sizeof(int)); if (!result) { printf("Failed to set from argument\n"); break; @@ -253,14 +204,14 @@ void rotary_emb_cl(float *in, float *out, const int work_groups_count[3] = {(int)batch, (int)channel, 1}; const int work_group_size[3] = {32, 32, 1}; // test-value - result = context.command_queue_inst_.DispatchCommand( - kernel_rotary_emb, work_groups_count, work_group_size); + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_rotaryEmb_ptr, work_groups_count, work_group_size); if (!result) { printf("Failed to dispatch command\n"); break; } - result = inOutRes.ReadData(context.command_queue_inst_, out); + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out); if (!result) { printf("Failed to read data\n"); break; diff --git a/nntrainer/tensor/cl_operations/attention_kernels.h b/nntrainer/tensor/cl_operations/attention_kernels.h index 97e2a98ce..37a3a4428 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels.h +++ b/nntrainer/tensor/cl_operations/attention_kernels.h @@ -14,17 +14,15 @@ #ifndef __ATTENTION_KERNELS_H__ #define __ATTENTION_KERNELS_H__ -#include +#include #include #include #include namespace nntrainer { -/** - * @brief declaring global kernel objects - */ -extern opencl::Kernel kernel_rotary_emb; +// get global cl_context to use in kernels +static ClContext cl_context_ref; /** * @brief Rotary Embedding process @@ -43,7 +41,6 @@ extern opencl::Kernel kernel_rotary_emb; * @param[in] max_timestep max timestep * @param[in] in_size size of input * @param[in] out_size size of output - * @param[in] context RunLayerContext reference */ void rotary_emb_cl(float *in, float *out, std::vector> freqs_cos, @@ -52,14 +49,9 @@ void rotary_emb_cl(float *in, float *out, unsigned int batch, unsigned int channel, unsigned int height, unsigned int width, unsigned int dim, unsigned int from, unsigned int max_timestamp, - unsigned int in_size, unsigned int out_size, - RunLayerContext &context); + unsigned int in_size, unsigned int out_size); #ifdef ENABLE_FP16 -/** - * @brief declaring global fp16 kernel objects - */ -extern opencl::Kernel kernel_rotary_emb_fp16; /** * @brief Rotary Embedding process @@ -78,7 +70,6 @@ extern opencl::Kernel kernel_rotary_emb_fp16; * @param[in] max_timestep max timestep * @param[in] in_size size of input * @param[in] out_size size of output - * @param[in] context RunLayerContext reference */ void rotary_emb_cl(__fp16 *in, __fp16 *out, std::vector> freqs_cos, @@ -87,8 +78,7 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, unsigned int batch, unsigned int channel, unsigned int height, unsigned int width, unsigned int dim, unsigned int from, unsigned int max_timestamp, - unsigned int in_size, unsigned int out_size, - RunLayerContext &context); + unsigned int in_size, unsigned int out_size); #endif diff --git a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp index 7c2c99502..c1284b0a9 100644 --- a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp +++ b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp @@ -11,66 +11,10 @@ * */ +#include #include 
diff --git a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp
index 7c2c99502..c1284b0a9 100644
--- a/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp
+++ b/nntrainer/tensor/cl_operations/attention_kernels_fp16.cpp
@@ -11,66 +11,10 @@
  *
  */
 
+#include <attention_kernel_strings.h>
 #include <attention_kernels.h>
 
 namespace nntrainer {
-std::string rotary_emb_cl_kernel_fp16 = R"(
-    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
-__kernel void rotary_emb_cl_fp16(__global half *input,
-                                 __global half *output,
-                                 __global float *freqs_cos,
-                                 __global float *freqs_sin,
-                                 __global float *cos_,
-                                 __global float *sin_,
-                                 unsigned int batch,
-                                 unsigned int channel,
-                                 unsigned int height,
-                                 unsigned int width,
-                                 unsigned int dim,
-                                 unsigned int half_,
-                                 unsigned int max_timestep,
-                                 unsigned int from) {
-  __global float *cos_ptr = cos_;
-  __global float *sin_ptr = sin_;
-
-  float value = 0.0f;
-  float transformed_value = 0.0f;
-
-  unsigned int b = get_global_id(0);
-  unsigned int c = get_global_id(1);
-
-  if(b < batch && c < channel){
-    for (unsigned int h = 0; h < height; h++) {
-      if (from + h < max_timestep) {
-        unsigned idx = (from + h)*dim;
-        for(int i = idx; i < idx + dim; i++ ){
-          cos_ptr[i - idx] = freqs_cos[i];
-          sin_ptr[i - idx] = freqs_sin[i];
-        }
-      }
-
-      for (unsigned int w = 0; w < width; w = w + dim) {
-        for (unsigned int k = 0; k < dim; k++) {
-          unsigned int span = w + k;
-          value = (float)input[b * channel * height * width + c * height * width + h * width + span];
-          if (k < half_) {
-            transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
-          } else {
-            transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
-          }
-          value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
-          output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
-        }
-      }
-    }
-  }
-}
-)";
-
-/**
- * @brief defining global kernel objects
- */
-opencl::Kernel kernel_rotary_emb_fp16;
 
 void rotary_emb_cl(__fp16 *in, __fp16 *out,
                    std::vector<std::vector<float>> freqs_cos,
@@ -79,16 +23,14 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out,
                    unsigned int batch, unsigned int channel,
                    unsigned int height, unsigned int width, unsigned int dim,
                    unsigned int from, unsigned int max_timestep,
-                   unsigned int in_size, unsigned int out_size,
-                   RunLayerContext &context) {
+                   unsigned int in_size, unsigned int out_size) {
   bool result = false;
 
   do {
-    result = context.clCreateKernel(rotary_emb_cl_kernel_fp16,
-                                    context.LayerKernel::ROTARY_EMB_FP16,
-                                    kernel_rotary_emb_fp16);
-    if (!result) {
-      printf("Failed to create kernel for rotary_emb_cl\n");
+    ClContext::SharedPtrClKernel kernel_rotaryEmb_fp16_ptr =
+      cl_context_ref.registerClKernel(rotary_emb_cl_kernel_fp16_,
+                                      "rotary_emb_cl_fp16");
+    if (!kernel_rotaryEmb_fp16_ptr) {
       break;
     }
 
@@ -104,18 +46,22 @@
     size_t dim5_size = sizeof(float) * freqs_cos_dim * dim;
     size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;
 
-    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+    opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
+                          nullptr);
 
-    opencl::Buffer inOutRes(context.context_inst_, dim2_size, true, nullptr);
+    opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true,
+                            nullptr);
 
-    opencl::Buffer cosBuf(context.context_inst_, dim3_size, true, nullptr);
+    opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true,
+                          nullptr);
 
-    opencl::Buffer sinBuf(context.context_inst_, dim4_size, true, nullptr);
+    opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true,
+                          nullptr);
 
-    opencl::Buffer freqs_cosBuf(context.context_inst_, dim5_size, true,
+    opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true,
                                 nullptr);
 
-    opencl::Buffer freqs_sinBuf(context.context_inst_, dim6_size, true,
+    opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true,
                                 nullptr);
 
     std::vector<float> freqs_cos_flat;
@@ -127,131 +73,137 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out,
       freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
     }
 
-    result = inputA.WriteData(context.command_queue_inst_, in);
+    result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
     if (!result) {
       printf("Failed to write input data\n");
       break;
     }
 
-    result = inOutRes.WriteData(context.command_queue_inst_, out);
+    result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out);
     if (!result) {
       printf("Failed to write output data\n");
       break;
     }
 
-    result = freqs_cosBuf.WriteData(context.command_queue_inst_,
+    result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_,
                                     freqs_cos_flat.data());
     if (!result) {
       printf("Failed to write freqs cos data\n");
       break;
     }
 
-    result = freqs_sinBuf.WriteData(context.command_queue_inst_,
+    result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_,
                                     freqs_sin_flat.data());
     if (!result) {
       printf("Failed to write freqs sin data\n");
       break;
     }
 
-    result = cosBuf.WriteData(context.command_queue_inst_, cos_.data());
+    result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data());
     if (!result) {
       printf("Failed to write cos data\n");
       break;
     }
 
-    result = sinBuf.WriteData(context.command_queue_inst_, sin_.data());
+    result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data());
     if (!result) {
       printf("Failed to write sin data\n");
       break;
     }
 
     result =
-      kernel_rotary_emb_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set inputA argument\n");
       break;
     }
 
-    result =
-      kernel_rotary_emb_fp16.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+    result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(1, &inOutRes,
+                                                           sizeof(cl_mem));
     if (!result) {
       printf("Failed to set inOutRes argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(2, &freqs_cosBuf,
-                                                       sizeof(cl_mem));
+    result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(2, &freqs_cosBuf,
                                                            sizeof(cl_mem));
     if (!result) {
       printf("Failed to set freqs_cosBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(3, &freqs_sinBuf,
-                                                       sizeof(cl_mem));
+    result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(3, &freqs_sinBuf,
                                                            sizeof(cl_mem));
     if (!result) {
       printf("Failed to set freqs_sinBuf argument\n");
       break;
     }
 
     result =
-      kernel_rotary_emb_fp16.SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set cosBuf argument\n");
       break;
     }
 
     result =
-      kernel_rotary_emb_fp16.SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
     if (!result) {
       printf("Failed to set sinBuf argument\n");
       break;
     }
 
-    result = kernel_rotary_emb_fp16.SetKernelArguments(6, &batch, sizeof(int));
+    result =
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(6, &batch, sizeof(int));
     if (!result) {
       printf("Failed to set batch argument\n");
       break;
     }
 
     result =
-      kernel_rotary_emb_fp16.SetKernelArguments(7, &channel, sizeof(int));
+      kernel_rotaryEmb_fp16_ptr->SetKernelArguments(7, &channel, sizeof(int));
     if (!result) {
       printf("Failed to set channel argument\n");
       break;
     }
 
kernel_rotaryEmb_fp16_ptr->SetKernelArguments(8, &height, sizeof(int)); if (!result) { printf("Failed to set height argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(9, &width, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(9, &width, sizeof(int)); if (!result) { printf("Failed to set width argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(10, &dim, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(10, &dim, sizeof(int)); if (!result) { printf("Failed to set dim argument\n"); break; } unsigned int half_ = dim / 2; - result = kernel_rotary_emb_fp16.SetKernelArguments(11, &half_, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(11, &half_, sizeof(int)); if (!result) { printf("Failed to set half argument\n"); break; } - result = - kernel_rotary_emb_fp16.SetKernelArguments(12, &max_timestep, sizeof(int)); + result = kernel_rotaryEmb_fp16_ptr->SetKernelArguments(12, &max_timestep, + sizeof(int)); if (!result) { printf("Failed to set timestamp argument\n"); break; } - result = kernel_rotary_emb_fp16.SetKernelArguments(13, &from, sizeof(int)); + result = + kernel_rotaryEmb_fp16_ptr->SetKernelArguments(13, &from, sizeof(int)); if (!result) { printf("Failed to set from argument\n"); break; @@ -259,14 +211,14 @@ void rotary_emb_cl(__fp16 *in, __fp16 *out, const int work_groups_count[3] = {(int)batch, (int)channel, 1}; const int work_group_size[3] = {32, 32, 1}; // test-value - result = context.command_queue_inst_.DispatchCommand( - kernel_rotary_emb_fp16, work_groups_count, work_group_size); + result = cl_context_ref.command_queue_inst_.DispatchCommand( + kernel_rotaryEmb_fp16_ptr, work_groups_count, work_group_size); if (!result) { printf("Failed to dispatch command\n"); break; } - result = inOutRes.ReadData(context.command_queue_inst_, out); + result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out); if (!result) { printf("Failed to read data\n"); break; diff --git a/nntrainer/tensor/cl_operations/meson.build b/nntrainer/tensor/cl_operations/meson.build index 3f186ec64..a1b9b795b 100644 --- a/nntrainer/tensor/cl_operations/meson.build +++ b/nntrainer/tensor/cl_operations/meson.build @@ -9,6 +9,7 @@ cl_op_headers = [ 'blas_kernel_interface.h', 'blas_kernel_strings.h', 'attention_kernel_interface.h', + 'attention_kernel_strings.h', ] if get_option('enable-fp16') diff --git a/test/unittest/unittest_attention_kernels_cl.cpp b/test/unittest/unittest_attention_kernels_cl.cpp index d2a26cc9d..a95937446 100644 --- a/test/unittest/unittest_attention_kernels_cl.cpp +++ b/test/unittest/unittest_attention_kernels_cl.cpp @@ -29,16 +29,13 @@ using namespace nntrainer; -static RunLayerContext setUpGpuContext() { - +static void setUpGpuContext() { auto &ac = nntrainer::ClContext::Global(); - auto rc = RunLayerContext(); - - return rc; + ac.initAttentionClKernels(); } TEST(attention_kernels, rotary_emb_kernel_FP32) { - RunLayerContext rc = setUpGpuContext(); + setUpGpuContext(); int batch = 1; int channel = 1; @@ -65,7 +62,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32) { B_fp32.copy(A_fp32); - apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc); + apply_rotary_emb_cl(A_fp32, dim, from, max_timestep); apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep); float mseErrorNeon_fp32 = @@ -81,7 +78,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32) { } TEST(attention_kernels, rotary_emb_kernel_FP32_case2) { - RunLayerContext rc = setUpGpuContext(); + setUpGpuContext(); 
diff --git a/test/unittest/unittest_attention_kernels_cl.cpp b/test/unittest/unittest_attention_kernels_cl.cpp
index d2a26cc9d..a95937446 100644
--- a/test/unittest/unittest_attention_kernels_cl.cpp
+++ b/test/unittest/unittest_attention_kernels_cl.cpp
@@ -29,16 +29,13 @@
 
 using namespace nntrainer;
 
-static RunLayerContext setUpGpuContext() {
-
+static void setUpGpuContext() {
   auto &ac = nntrainer::ClContext::Global();
-  auto rc = RunLayerContext();
-
-  return rc;
+  ac.initAttentionClKernels();
 }
 
 TEST(attention_kernels, rotary_emb_kernel_FP32) {
-  RunLayerContext rc = setUpGpuContext();
+  setUpGpuContext();
 
   int batch = 1;
   int channel = 1;
@@ -65,7 +62,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32) {
 
   B_fp32.copy(A_fp32);
 
-  apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc);
+  apply_rotary_emb_cl(A_fp32, dim, from, max_timestep);
   apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep);
 
   float mseErrorNeon_fp32 =
@@ -81,7 +78,7 @@
 }
 
 TEST(attention_kernels, rotary_emb_kernel_FP32_case2) {
-  RunLayerContext rc = setUpGpuContext();
+  setUpGpuContext();
 
   int batch = 4;
   int channel = 4;
@@ -108,7 +105,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP32_case2) {
 
   B_fp32.copy(A_fp32);
 
-  apply_rotary_emb_cl(A_fp32, dim, from, max_timestep, rc);
+  apply_rotary_emb_cl(A_fp32, dim, from, max_timestep);
   apply_rotary_emb_tensor(B_fp32, dim, from, max_timestep);
 
   float mseErrorNeon_fp32 =
@@ -124,7 +121,7 @@
 }
 
 TEST(attention_kernels, rotary_emb_kernel_FP16) {
-  RunLayerContext rc = setUpGpuContext();
+  setUpGpuContext();
 
   int batch = 1;
   int channel = 1;
@@ -150,7 +147,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP16) {
 
   B_fp16.copy(A_fp16);
 
-  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep);
   apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
 
   float mseErrorNeon_fp16 = mse<__fp16>(
@@ -166,7 +163,7 @@
 }
 
 TEST(attention_kernels, rotary_emb_kernel_FP16_case2) {
-  RunLayerContext rc = setUpGpuContext();
+  setUpGpuContext();
 
   int batch = 4;
   int channel = 4;
@@ -192,7 +189,7 @@ TEST(attention_kernels, rotary_emb_kernel_FP16_case2) {
 
   B_fp16.copy(A_fp16);
 
-  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep, rc);
+  apply_rotary_emb_cl(A_fp16, dim, from, max_timestep);
   apply_rotary_emb_tensor(B_fp16, dim, from, max_timestep);
 
   float mseErrorNeon_fp16 = mse<__fp16>(