[GPU/Enhance] Registering Attention kernels and removing cl_context dependency

Added the registerClKernel function to register custom OpenCL kernels as well as in-house kernels (usage sketch below).
Modified the attention kernels to remove cl_context-related dependencies.
Added the initAttentionClKernels function to register the default attention kernels.
Modified the unit test to remove the layer_context dependency.
Added attention_kernel_strings.h to keep all attention kernel strings in one place.
Rebased the PR with the current log.
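
For illustration only (not part of this commit): a minimal sketch of registering a user-defined kernel through the new entry point, assuming a nntrainer::ClContext instance is available to the caller; scale_by_two_kernel and register_custom_kernel are made up for this example.

#include <cl_context.h>
#include <string>

// Hypothetical custom kernel used only to illustrate registerClKernel.
static const std::string scale_by_two_kernel = R"(
__kernel void scale_by_two(__global float *buf) {
  unsigned int i = get_global_id(0);
  buf[i] = buf[i] * 2.0f;
})";

void register_custom_kernel(nntrainer::ClContext &cl_context) {
  // Register the user-defined kernel; the returned handle is assumed to be
  // shared-pointer-like, so a falsy value indicates a failed registration.
  const auto &kernel =
    cl_context.registerClKernel(scale_by_two_kernel, "scale_by_two");
  if (!kernel) {
    // Registration or compilation failed; handle the error as appropriate.
  }
}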

Signed-off-by: Yash Singh <yash.singh@samsung.com>
yashSingh0723 committed Oct 8, 2024
1 parent d8dc404 commit 3e0e6af
Showing 10 changed files with 272 additions and 226 deletions.
16 changes: 16 additions & 0 deletions nntrainer/cl_context.cpp
@@ -15,6 +15,7 @@
*/

#include <addition_layer_cl.h>
#include <attention_kernel_strings.h>
#include <blas_kernel_strings.h>
#include <cl_context.h>
#include <concat_cl.h>
@@ -153,6 +154,21 @@ void ClContext::initBlasClKernels() {
  blas_kernels_initialized = true;
}

void ClContext::initAttentionClKernels() {
  if (attention_kernels_initialized) {
    ml_logi("ClContext: Default attention kernels already registered and "
            "initialized");
    return;
  }

  registerClKernel(rotary_emb_cl_kernel_, "rotary_emb_cl");

#ifdef ENABLE_FP16
  registerClKernel(rotary_emb_cl_kernel_fp16_, "rotary_emb_cl_fp16");
#endif
  attention_kernels_initialized = true;
}

const ClContext::SharedPtrClKernel &
ClContext::registerClKernel(std::string kernel_string,
std::string kernel_name) {
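For context, a hedged sketch of how the two default-kernel initializers might be invoked together; the init_opencl_kernels helper and its call site are hypothetical, since the actual call site is not part of this diff.

void init_opencl_kernels(nntrainer::ClContext &cl_context) {
  cl_context.initBlasClKernels();      // existing default BLAS kernels
  cl_context.initAttentionClKernels(); // default attention kernels added here
  // Repeated calls are cheap: the *_initialized guard flags make both
  // functions return early after the first successful registration.
}
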
8 changes: 8 additions & 0 deletions nntrainer/cl_context.h
@@ -211,6 +211,11 @@ class ClContext {
*/
void initBlasClKernels();

/**
* @brief Initialize and register all attention OpenCL kernels
*/
void initAttentionClKernels();

/**
* @brief destructor to release opencl commandQueue
*/
@@ -229,6 +234,9 @@
// flag to check default blas kernels registered or not
bool blas_kernels_initialized = false;

// flag to check default attention kernels registered or not
bool attention_kernels_initialized = false;

FactoryMap<nntrainer::Layer> factory_map;

template <typename Args, typename T> struct isSupportedHelper;
7 changes: 3 additions & 4 deletions nntrainer/tensor/cl_operations/attention_kernel_interface.cpp
@@ -59,12 +59,11 @@ void precompute_freqs(unsigned int dim, unsigned int seq_len,
* @param[in] dim hidden dim size
* @param[in] from sequence order
* @param[in] max_timestep maximum timestep
* @param[in] context layer context to get the resource manager and queue id
*
* @todo Calling precompute_freqs in finalize to reduce code redundancy.
*/
void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
unsigned int max_timestep, RunLayerContext &context) {
unsigned int max_timestep) {
nntrainer::Tensor out(in.getDim());
float value = 0.0f;
float transformed_value = 0.0f;
@@ -111,7 +110,7 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,

rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
input_batch_size, input_channels, input_height, input_width,
dim, from, max_timestep, in_size, out_size, context);
dim, from, max_timestep, in_size, out_size);

} else if (in.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
@@ -123,7 +122,7 @@ void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,

rotary_emb_cl(data, rdata, freqs_cos, freqs_sin, cos_, sin_,
input_batch_size, input_channels, input_height, input_width,
dim, from, max_timestep, in_size, out_size, context);
dim, from, max_timestep, in_size, out_size);
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
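For reference, a hedged usage sketch of the simplified interface; rotate_query and all parameter values below are illustrative, not taken from the nntrainer sources.

#include <attention_kernel_interface.h>
#include <tensor.h>

void rotate_query(nntrainer::Tensor &query) {
  unsigned int dim = 64;           // rotary dimension per head (assumed)
  unsigned int from = 0;           // starting sequence position
  unsigned int max_timestep = 512; // maximum timestep (assumed)

  // The RunLayerContext argument is gone; only the tensor and the rotary
  // parameters are required now.
  nntrainer::apply_rotary_emb_cl(query, dim, from, max_timestep);
}
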
5 changes: 2 additions & 3 deletions nntrainer/tensor/cl_operations/attention_kernel_interface.h
@@ -14,8 +14,8 @@
#ifndef __ATTENTION_KERNEL_INTERFACE_H__
#define __ATTENTION_KERNEL_INTERFACE_H__

#include <layer_context.h>
#include <string>
#include <tensor.h>

namespace nntrainer {

@@ -25,10 +25,9 @@ namespace nntrainer {
* @param[in] dim hidden dim size
* @param[in] from sequence order
* @param[in] max_timestep maximum timestep
* @param[in] context layer context to get the resource manager and queue id
*/
void apply_rotary_emb_cl(Tensor &in, unsigned int dim, unsigned int from,
unsigned int max_timestep, RunLayerContext &context);
unsigned int max_timestep);

} // namespace nntrainer
#endif /* __ATTENTION_KERNEL_INTERFACE_H__ */
133 changes: 133 additions & 0 deletions nntrainer/tensor/cl_operations/attention_kernel_strings.h
@@ -0,0 +1,133 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Yash Singh <yash.singh@samsung.com>
*
* @file attention_kernel_strings.h
* @date 8 October 2024
* @brief All attention OpenCL kernel strings
* @see https://github.com/nnstreamer/nntrainer
* @author Yash Singh <yash.singh@samsung.com>
* @bug No known bugs except for NYI items
*
*/

#ifndef __ATTENTION_KERNEL_STRINGS_H__
#define __ATTENTION_KERNEL_STRINGS_H__

#include <string>

namespace nntrainer {
static const std::string rotary_emb_cl_kernel_ = R"(
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void rotary_emb_cl(__global float *input,
                            __global float *output,
                            __global float *freqs_cos,
                            __global float *freqs_sin,
                            __global float *cos_,
                            __global float *sin_,
                            unsigned int batch,
                            unsigned int channel,
                            unsigned int height,
                            unsigned int width,
                            unsigned int dim,
                            unsigned int half_,
                            unsigned int max_timestep,
                            unsigned int from) {
  __global float *cos_ptr = cos_;
  __global float *sin_ptr = sin_;
  float value = 0.0f;
  float transformed_value = 0.0f;
  unsigned int b = get_global_id(0);
  unsigned int c = get_global_id(1);
  if (b < batch && c < channel) {
    for (unsigned int h = 0; h < height; h++) {
      if (from + h < max_timestep) {
        unsigned idx = (from + h) * dim;
        for (unsigned int i = idx; i < idx + dim; i++) {
          cos_ptr[i - idx] = freqs_cos[i];
          sin_ptr[i - idx] = freqs_sin[i];
        }
      }
      for (unsigned int w = 0; w < width; w = w + dim) {
        for (unsigned int k = 0; k < dim; k++) {
          unsigned int span = w + k;
          value = input[b * channel * height * width + c * height * width + h * width + span];
          if (k < half_) {
            transformed_value = -1.0f * input[b * channel * height * width + c * height * width + h * width + span + half_];
          } else {
            transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
          }
          value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
          output[b * channel * height * width + c * height * width + h * width + span] = value;
        }
      }
    }
  }
}
)";

#ifdef ENABLE_FP16
static const std::string rotary_emb_cl_kernel_fp16_ = R"(
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void rotary_emb_cl_fp16(__global half *input,
                                 __global half *output,
                                 __global float *freqs_cos,
                                 __global float *freqs_sin,
                                 __global float *cos_,
                                 __global float *sin_,
                                 unsigned int batch,
                                 unsigned int channel,
                                 unsigned int height,
                                 unsigned int width,
                                 unsigned int dim,
                                 unsigned int half_,
                                 unsigned int max_timestep,
                                 unsigned int from) {
  __global float *cos_ptr = cos_;
  __global float *sin_ptr = sin_;
  float value = 0.0f;
  float transformed_value = 0.0f;
  unsigned int b = get_global_id(0);
  unsigned int c = get_global_id(1);
  if (b < batch && c < channel) {
    for (unsigned int h = 0; h < height; h++) {
      if (from + h < max_timestep) {
        unsigned idx = (from + h) * dim;
        for (unsigned int i = idx; i < idx + dim; i++) {
          cos_ptr[i - idx] = freqs_cos[i];
          sin_ptr[i - idx] = freqs_sin[i];
        }
      }
      for (unsigned int w = 0; w < width; w = w + dim) {
        for (unsigned int k = 0; k < dim; k++) {
          unsigned int span = w + k;
          value = (float)input[b * channel * height * width + c * height * width + h * width + span];
          if (k < half_) {
            transformed_value = -1.0f * (float)input[b * channel * height * width + c * height * width + h * width + span + half_];
          } else {
            transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
          }
          value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
          output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
        }
      }
    }
  }
}
)";

#endif
} // namespace nntrainer
#endif /* __ATTENTION_KERNEL_STRINGS_H__ */
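
To make the indexing in both kernels easier to follow, here is a hedged host-side C++ reference of the per-element rotation applied to a single dim-sized block; rotary_reference is illustrative only and not part of this commit.

#include <vector>

// out[k] = in[k] * cos[k] + rotate_half(in)[k] * sin[k], where rotate_half
// pairs element k with the element half_ positions away and negates the
// partner taken from the second half of the block.
void rotary_reference(const std::vector<float> &in, std::vector<float> &out,
                      const std::vector<float> &cos_v,
                      const std::vector<float> &sin_v, unsigned int dim) {
  unsigned int half = dim / 2;
  for (unsigned int k = 0; k < dim; ++k) {
    float partner = (k < half) ? -in[k + half] // first half: negated second half
                               : in[k - half]; // second half: first half
    out[k] = in[k] * cos_v[k] + partner * sin_v[k];
  }
}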