diff --git a/nntrainer/tensor/blas_interface.cpp b/nntrainer/tensor/blas_interface.cpp
index d4d21d4fc5..47a9a7546f 100644
--- a/nntrainer/tensor/blas_interface.cpp
+++ b/nntrainer/tensor/blas_interface.cpp
@@ -24,6 +24,12 @@
 #include
 #endif
 
+#ifdef USE_BLAS
+extern "C" {
+#include <cblas.h>
+}
+#endif
+
 #include
 
 #define sgemv_loop(ci, cj, cM, cN)                                             \
@@ -57,23 +63,23 @@
     Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX];       \
   } while (0);
 
-#define hgemm_loop()                                                           \
-  do {                                                                         \
-    for (unsigned int m = 0; m < M; ++m) {                                     \
-      for (unsigned int n = 0; n < N; ++n) {                                   \
-        float c = 0;                                                           \
-        _FP16 c_old = C[m * ldc + n];                                          \
-        for (unsigned int k = 0; k < K; ++k) {                                 \
-          _FP16 a, b;                                                          \
-          a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);      \
-          b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);      \
-          c += static_cast<float>(a * b);                                      \
-        }                                                                      \
-        C[m * ldc + n] = static_cast<_FP16>(alpha * c);                        \
-        if (beta != 0.0)                                                       \
-          C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;                  \
-      }                                                                        \
-    }                                                                          \
+#define hgemm_loop()                                                           \
+  do {                                                                         \
+    for (unsigned int m = 0; m < M; ++m) {                                     \
+      for (unsigned int n = 0; n < N; ++n) {                                   \
+        float c = 0;                                                           \
+        _FP16 c_old = C[m * ldc + n];                                          \
+        for (unsigned int k = 0; k < K; ++k) {                                 \
+          _FP16 a, b;                                                          \
+          a = ((TransA) ? A[k * lda + m] : A[m * lda + k]);                    \
+          b = ((TransB) ? B[n * ldb + k] : B[k * ldb + n]);                    \
+          c += static_cast<float>(a * b);                                      \
+        }                                                                      \
+        C[m * ldc + n] = static_cast<_FP16>(alpha * c);                        \
+        if (beta != 0.0)                                                       \
+          C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;                  \
+      }                                                                        \
+    }                                                                          \
   } while (0);
 
 namespace nntrainer {
@@ -94,7 +100,7 @@ static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
                        const int incX, _FP16 *Y, const int incY) {
   if (incX < 0 or incY < 0)
     throw std::invalid_argument(
-      "Error: negative inc not supported without cblas");
+      "Error: negative inc not supported");
 
 #if (defined USE__FP16 && USE_NEON)
   // USE__FP16 is defined when platform is android
@@ -108,22 +114,22 @@ static void saxpy_FP16(const unsigned int N, const float alpha, const _FP16 *X,
 #endif
 }
 
-static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
+static void sgemv_FP16(const unsigned int TStorageOrder, bool TransA,
                        const unsigned int M, const unsigned int N,
                        const float alpha, const _FP16 *A, const unsigned int lda,
                        const _FP16 *X, const int incX, const float beta,
                        _FP16 *Y, const int incY) {
 #if (defined USE__FP16 && USE_NEON)
-  if (TransA == CblasTrans) {
+  if (TransA) {
     nntrainer::neon::hgemv_transpose(A, X, Y, M, N, alpha, beta);
   } else {
     nntrainer::neon::hgemv(A, X, Y, M, N, alpha, beta);
   }
 #else
   unsigned int lenX =
-    (TransA == CblasTrans) ? 1 + (M - 1) * abs(incX) : 1 + (N - 1) * abs(incX);
+    (TransA) ? 1 + (M - 1) * abs(incX) : 1 + (N - 1) * abs(incX);
   unsigned int lenY =
-    (TransA == CblasTrans) ? 1 + (N - 1) * abs(incY) : 1 + (M - 1) * abs(incY);
+    (TransA) ? 1 + (N - 1) * abs(incY) : 1 + (M - 1) * abs(incY);
 
   float *A_ = new float[M * N];
   float *X_ = new float[lenX];
@@ -317,18 +323,20 @@ static _FP16 snrm2_FP16(const unsigned int N, const _FP16 *X, const int incX) {
   return sum;
 }
 
-static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
-                       CBLAS_TRANSPOSE TransB, const unsigned int M,
-                       const unsigned int N, const unsigned int K,
-                       const float alpha, const _FP16 *A,
+static void sgemm_FP16(const unsigned int TStorageOrder, bool TransA,
+                       bool TransB, const unsigned int M, const unsigned int N,
+                       const unsigned int K, const float alpha, const _FP16 *A,
                        const unsigned int lda, const _FP16 *B,
                        const unsigned int ldb, const float beta, _FP16 *C,
                        const unsigned int ldc) {
 #if (defined USE__FP16 && USE_NEON)
-  nntrainer::neon::custom_hgemm(A, B, C, M, N, K, alpha, beta,
-                                TransA == CblasTrans, TransB == CblasTrans);
+  nntrainer::neon::custom_hgemm(A, B, C, M, N, K, alpha, beta, TransA, TransB);
 #else
+  CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+  CBLAS_TRANSPOSE transB = TransB ? CblasTrans : CblasNoTrans;
+  CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
+
   float *A_ = new float[M * K];
   float *B_ = new float[N * K];
   float *C_ = new float[M * N];
@@ -336,7 +344,7 @@ static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   scopy(M * K, A, 1, A_, 1);
   scopy(N * K, B, 1, B_, 1);
   scopy(M * N, C, 1, C_, 1);
-  sgemm(order, TransA, TransB, M, N, K, alpha, A_, lda, B_, ldb, beta, C_, ldc);
+  sgemm(order, transA, transB, M, N, K, alpha, A_, lda, B_, ldb, beta, C_, ldc);
   scopy(M * N, C_, 1, C, 1);
 
   delete[] A_;
@@ -381,13 +389,13 @@ void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
   saxpy_FP16(N, alpha, X, incX, Y, incY);
 }
 
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
            const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const _FP16 *A, const unsigned int lda,
           const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
           const unsigned int ldc) {
-  sgemm_FP16(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
-             ldc);
+  sgemm_FP16(TStorageOrder, TransA, TransB, M, N, K, alpha, A, lda, B, ldb,
+             beta, C, ldc);
 }
 
 void scopy(const unsigned int N, const _FP16 *X, const int incX, _FP16 *Y,
@@ -520,11 +528,12 @@ _FP16 sdot(const unsigned int N, const _FP16 *X, const unsigned int incX,
   return sdot_FP16(N, X, incX, Y, incY);
 }
 
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
           const unsigned int N, const float alpha, const _FP16 *A,
           const unsigned int lda, const _FP16 *X, const int incX,
           const float beta, _FP16 *Y, const int incY) {
-  sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+  sgemv_FP16(TStorageOrder, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
+             incY);
 }
 
 unsigned int isamax(const unsigned int N, const _FP16 *X, const int incX) {
@@ -558,12 +567,12 @@ static void saxpy_raw(const unsigned int N, const float alpha, const float *X,
                       const int incX, float *Y, const int incY) {
   if (incX < 0 or incY < 0)
     throw std::invalid_argument(
-      "Error: negative inc not supported without cblas");
+      "Error: negative inc not supported");
   for (unsigned int i = 0; i < N; ++i)
     Y[i * incY] = Y[i * incY] + X[i * incX] * alpha;
 }
 
-static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
+static void sgemv_raw(const unsigned int TStorageOrder, bool TransA,
                       const unsigned int M, const unsigned int N,
                       const float alpha, const float *A, const unsigned int lda,
                       const float *X, const int incX, const float beta,
@@ -572,7 +581,7 @@ static void sgemv_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
   unsigned int incy = abs(incY);
   unsigned int incx = abs(incX);
 
-  if (TransA == CblasTrans) {
+  if (TransA) {
     sgemv_loop(i, j, N, M);
   } else {
     sgemv_loop(j, i, M, N);
   }
@@ -618,12 +627,12 @@ static float snrm2_raw(const unsigned int N, const float *X, const int incX) {
   return sqrt(sum);
 }
 
-static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
-                      CBLAS_TRANSPOSE TransB, const unsigned int M,
-                      const unsigned int N, const unsigned int K,
-                      const float alpha, const float *A, const unsigned int lda,
-                      const float *B, const unsigned int ldb, const float beta,
-                      float *C, const unsigned int ldc) {
+static void sgemm_raw(const unsigned int TStorageOrder, bool TransA,
+                      bool TransB, const unsigned int M, const unsigned int N,
+                      const unsigned int K, const float alpha, const float *A,
+                      const unsigned int lda, const float *B,
+                      const unsigned int ldb, const float beta, float *C,
+                      const unsigned int ldc) {
 
   for (unsigned int m = 0; m < M; ++m) {
     for (unsigned int n = 0; n < N; ++n) {
@@ -631,8 +640,8 @@ static void sgemm_raw(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
       float c_old = C[m * ldc + n];
       for (unsigned int k = 0; k < K; ++k) {
         float a, b;
-        a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
-        b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
+        a = ((TransA) ? A[k * lda + m] : A[m * lda + k]);
+        b = ((TransB) ? B[n * ldb + k] : B[k * ldb + n]);
         c += a * b;
       }
       C[m * ldc + n] = alpha * c;
@@ -729,12 +738,11 @@ void saxpy(const unsigned int N, const float alpha, const float *X,
 #endif
 }
 
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
           const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const void *A, const unsigned int lda,
           const void *B, const unsigned int ldb, const float beta, void *C,
           const unsigned int ldc, ml::train::TensorDim::DataType d_type) {
-
   if (d_type == ml::train::TensorDim::DataType::FP32) {
 #ifdef USE_CUBLAS
     int devID = 0;
@@ -755,10 +763,8 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
     cublasHandle_t handle;
     cublasCreate(&handle);
 
-    cublasOperation_t transA =
-      (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
-    cublasOperation_t transB =
-      (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t transA = (TransA) ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t transB = (TransB) ? CUBLAS_OP_T : CUBLAS_OP_N;
     cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
                 d_C, N);
@@ -770,33 +776,35 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
 #ifdef BLAS_NUM_THREADS
     openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-
+    CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+    CBLAS_TRANSPOSE transB = TransB ? CblasTrans : CblasNoTrans;
+    CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
     cblas_sgemm(
-      order, TransA, TransB, M, N, K, alpha, static_cast<const float *>(A), lda,
+      order, transA, transB, M, N, K, alpha, static_cast<const float *>(A), lda,
       static_cast<const float *>(B), ldb, beta, static_cast<float *>(C), ldc);
 #else
-    sgemm_raw(order, TransA, TransB, M, N, K, alpha,
+    sgemm_raw(TStorageOrder, TransA, TransB, M, N, K, alpha,
              static_cast<const float *>(A), lda, static_cast<const float *>(B),
              ldb, beta, static_cast<float *>(C), ldc);
 #endif
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    sgemm_FP16(
-      order, TransA, TransB, M, N, K, alpha, static_cast<const _FP16 *>(A), lda,
-      static_cast<const _FP16 *>(B), ldb, beta, static_cast<_FP16 *>(C), ldc);
+    sgemm_FP16(TStorageOrder, TransA, TransB, M, N, K, alpha,
+               static_cast<const _FP16 *>(A), lda,
+               static_cast<const _FP16 *>(B), ldb, beta,
+               static_cast<_FP16 *>(C), ldc);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   }
 } // namespace nntrainer
 
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
           const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const float *A, const unsigned int lda,
           const float *B, const unsigned int ldb, const float beta, float *C,
           const unsigned int ldc) {
-
 #ifdef USE_CUBLAS
   int devID = 0;
   cudaDeviceProp deviceProp;
@@ -816,8 +824,8 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
   cublasHandle_t handle;
   cublasCreate(&handle);
 
-  cublasOperation_t transA = (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
-  cublasOperation_t transB = (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transA = (TransA) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transB = (TransB) ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
               d_C, N);
@@ -827,11 +835,14 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
 #ifdef BLAS_NUM_THREADS
   openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-  cblas_sgemm(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
+  CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+  CBLAS_TRANSPOSE transB = TransB ? CblasTrans : CblasNoTrans;
+  CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
+  cblas_sgemm(order, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C,
               ldc);
 #else
-  sgemm_raw(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
-            ldc);
+  sgemm_raw(TStorageOrder, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta,
+            C, ldc);
 #endif
 }
@@ -927,37 +938,39 @@ float sdot(const unsigned int N, const float *X, const unsigned int incX,
 #endif
 }
 
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
           const unsigned int N, const float alpha, const void *A,
           const unsigned int lda, const void *X, const int incX,
           const float beta, void *Y, const int incY,
           ml::train::TensorDim::DataType d_type) {
+
   if (d_type == ml::train::TensorDim::DataType::FP32) {
 #ifdef USE_BLAS
 #ifdef BLAS_NUM_THREADS
     openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
+    CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+    CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
     return cblas_sgemv(
-      order, TransA, M, N, alpha, static_cast<const float *>(A), lda,
+      order, transA, M, N, alpha, static_cast<const float *>(A), lda,
      static_cast<const float *>(X), incX, beta, static_cast<float *>(Y), incY);
 #else
-
-    return sgemv_raw(order, TransA, M, N, alpha, static_cast<const float *>(A),
-                     lda, static_cast<const float *>(X), incX, beta,
-                     static_cast<float *>(Y), incY);
+    return sgemv_raw(
+      TStorageOrder, TransA, M, N, alpha, static_cast<const float *>(A), lda,
+      static_cast<const float *>(X), incX, beta, static_cast<float *>(Y), incY);
#endif
   } else if (d_type == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    return sgemv_FP16(order, TransA, M, N, alpha, static_cast<const _FP16 *>(A),
-                      lda, static_cast<const _FP16 *>(X), incX, beta,
-                      static_cast<_FP16 *>(Y), incY);
+    return sgemv_FP16(
+      TStorageOrder, TransA, M, N, alpha, static_cast<const _FP16 *>(A), lda,
+      static_cast<const _FP16 *>(X), incX, beta, static_cast<_FP16 *>(Y), incY);
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
   }
 }
 
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
           const unsigned int N, const float alpha, const float *A,
           const unsigned int lda, const float *X, const int incX,
           const float beta, float *Y, const int incY) {
@@ -965,10 +978,13 @@ void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
 #ifdef BLAS_NUM_THREADS
   openblas_set_num_threads(BLAS_NUM_THREADS);
 #endif
-  return cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
+  CBLAS_TRANSPOSE transA = TransA ? CblasTrans : CblasNoTrans;
+  CBLAS_ORDER order = TStorageOrder ? CblasColMajor : CblasRowMajor;
+  return cblas_sgemv(order, transA, M, N, alpha, A, lda, X, incX, beta, Y,
                      incY);
 #else
-  return sgemv_raw(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
+  return sgemv_raw(TStorageOrder, TransA, M, N, alpha, A, lda, X, incX, beta, Y,
+                   incY);
 #endif
 }
diff --git a/nntrainer/tensor/blas_interface.h b/nntrainer/tensor/blas_interface.h
index 69cdda01f9..b57ea3e057 100644
--- a/nntrainer/tensor/blas_interface.h
+++ b/nntrainer/tensor/blas_interface.h
@@ -16,21 +16,6 @@
 #define __BLAS_INTERFACE_H_
 
 #ifdef __cplusplus
-#ifdef USE_BLAS
-extern "C" {
-#include <cblas.h>
-}
-#else
-enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 };
-
-enum CBLAS_TRANSPOSE {
-  CblasNoTrans = 111,
-  CblasTrans = 112,
-  CblasConjTrans = 113
-};
-
-#endif
-
 #ifdef USE_CUBLAS
 #include
 #include
@@ -132,7 +117,7 @@ void saxpy(const unsigned int N, const float alpha, const _FP16 *X,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
           const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const _FP16 *A, const unsigned int lda,
           const _FP16 *B, const unsigned int ldb, const float beta, _FP16 *C,
@@ -147,7 +132,7 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
           const unsigned int N, const float alpha, const _FP16 *A,
           const unsigned int lda, const _FP16 *X, const int incX,
           const float beta, _FP16 *Y, const int incY);
@@ -346,7 +331,7 @@ void saxpy(const unsigned int N, const float alpha, const float *X,
  * @param[in] alpha float number
  * @param[in] beta float number
 */
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
           const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const void *A, const unsigned int lda,
           const void *B, const unsigned int ldb, const float beta, void *C,
@@ -363,7 +348,7 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
  * @param[in] alpha float number
  * @param[in] beta float number
 */
-void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
+void sgemm(const unsigned int TStorageOrder, bool TransA, bool TransB,
           const unsigned int M, const unsigned int N, const unsigned int K,
           const float alpha, const float *A, const unsigned int lda,
           const float *B, const unsigned int ldb, const float beta, float *C,
@@ -378,7 +363,7 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
  * @param[in] alpha float number
  * @param[in] beta float number
 */
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
           const unsigned int N, const float alpha, const void *A,
           const unsigned int lda, const void *X, const int incX,
           const float beta, void *Y, const int incY,
@@ -393,7 +378,7 @@ void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
  * @param[in] alpha float number
  * @param[in] beta float number
 */
-void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
+void sgemv(const unsigned int TStorageOrder, bool TransA, const unsigned int M,
           const unsigned int N, const float alpha, const float *A,
           const unsigned int lda, const float *X, const int incX,
           const float beta, float *Y, const int incY);
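Reviewer note, not part of the patch: a minimal caller-side sketch of the new sgemm signature declared above, to make the flag convention concrete. The matrix sizes, variable names and the example_gemm wrapper are illustrative assumptions; the only facts taken from the patch are the parameter order and the conversion used in blas_interface.cpp (TStorageOrder == 0 selects CblasRowMajor, non-zero selects CblasColMajor, and the bools map to CblasTrans/CblasNoTrans).

// Sketch under the assumptions above; requires the nntrainer include path.
#include <blas_interface.h>

void example_gemm() {
  const unsigned int storage_order = 0; // 0: row-major, non-zero: col-major
  const bool trans_a = false;           // booleans replace CBLAS_TRANSPOSE
  const bool trans_b = false;
  float A[2 * 3] = {1, 2, 3, 4, 5, 6};  // 2x3, lda = 3
  float B[3 * 2] = {1, 0, 0, 1, 1, 1};  // 3x2, ldb = 2
  float C[2 * 2] = {0, 0, 0, 0};        // 2x2, ldc = 2
  // C = 1.0f * A * B + 0.0f * C
  nntrainer::sgemm(storage_order, trans_a, trans_b, 2, 2, 3, 1.0f, A, 3, B, 2,
                   0.0f, C, 2);
}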
diff --git a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
index 9e8422d404..c1ecf2ddc1 100644
--- a/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
@@ -119,8 +119,6 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     const float *data = input.getData();
     const float *mdata = m.getData();
     float *rdata = result.getData();
-    enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
-    enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;
 
     /// shortcut handling in case of vector
     /// for vector, (1 * K) == (K * 1) in current memory layout...
@@ -134,20 +132,19 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     }
     /// case2: (M * K) X (K * 1)
     else if (N == 1) {
-      transA ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
-             : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
+      trans ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
+            : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
     }
     /// case3: (1 * K) X (K * N) = 1 * N = R
     /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
     /// Effectively a translation of sgemv
     else if (M == 1) {
-      transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
-      transB ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
-             : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
+      trans_m ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
+              : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
     }
     /// case others: use gemm
     else {
-      sgemm_cl(transA, transB, data, mdata, rdata, M, N, K, lda, ldb, ldc,
+      sgemm_cl(trans, trans_m, data, mdata, rdata, M, N, K, lda, ldb, ldc,
               context);
     }
   } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
@@ -155,8 +152,6 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     const _FP16 *data = input.getData<_FP16>();
     const _FP16 *mdata = m.getData<_FP16>();
     _FP16 *rdata = result.getData<_FP16>();
-    enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
-    enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;
 
     /// shortcut handling in case of vector
     /// for vector, (1 * K) == (K * 1) in current memory layout...
@@ -170,20 +165,19 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result,
     }
     /// case2: (M * K) X (K * 1)
     else if (N == 1) {
-      transA ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
-             : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
+      trans ? sgemv_cl(data, mdata, rdata, dim2, dim1, lda, context)
+            : sgemv_cl(data, mdata, rdata, dim1, dim2, lda, context);
     }
     /// case3: (1 * K) X (K * N) = 1 * N = R
     /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
     /// Effectively a translation of sgemv
     else if (M == 1) {
-      transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
-      transB ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
-             : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
+      trans_m ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb, context)
+              : sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb, context);
     }
     /// case others: use sgemm
     else {
-      sgemm_cl(transA, transB, data, mdata, rdata, M, N, K, lda, ldb, ldc,
+      sgemm_cl(trans, trans_m, data, mdata, rdata, M, N, K, lda, ldb, ldc,
               context);
     }
 #else
diff --git a/nntrainer/tensor/cl_operations/blas_kernels.cpp b/nntrainer/tensor/cl_operations/blas_kernels.cpp
index 791cdc5e6b..5c0d1dfa72 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernels.cpp
@@ -282,24 +282,24 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
   return cl_ret;
 }
 
-void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const float *A,
-              const float *B, float *C, unsigned int M, unsigned int N,
-              unsigned int K, unsigned int lda, unsigned int ldb,
-              unsigned int ldc, RunLayerContext &context) {
+void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
+              float *C, unsigned int M, unsigned int N, unsigned int K,
+              unsigned int lda, unsigned int ldb, unsigned int ldc,
+              RunLayerContext &context) {
 
   opencl::Kernel *kernel_sgemm = nullptr;
   RunLayerContext::LayerKernel layerKernel;
   std::string sgemm_cl_kernel_;
 
-  if (TransA != CblasTrans && TransB != CblasTrans) {
+  if (!TransA && !TransB) {
     kernel_sgemm = &kernel_sgemm_noTrans;
     layerKernel = context.LayerKernel::SGEMM_NOTRANS;
     sgemm_cl_kernel_ = sgemm_cl_noTrans_kernel_;
-  } else if (TransA == CblasTrans && TransB != CblasTrans) {
+  } else if (TransA && !TransB) {
     kernel_sgemm = &kernel_sgemm_transA;
     layerKernel = context.LayerKernel::SGEMM_TRANSA;
     sgemm_cl_kernel_ = sgemm_cl_transA_kernel_;
-  } else if (TransA != CblasTrans && TransB == CblasTrans) {
+  } else if (!TransA && TransB) {
     kernel_sgemm = &kernel_sgemm_transB;
     layerKernel = context.LayerKernel::SGEMM_TRANSB;
     sgemm_cl_kernel_ = sgemm_cl_transB_kernel_;
diff --git a/nntrainer/tensor/cl_operations/blas_kernels.h b/nntrainer/tensor/cl_operations/blas_kernels.h
index 6b118c68dd..008345eef2 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels.h
+++ b/nntrainer/tensor/cl_operations/blas_kernels.h
@@ -61,8 +61,8 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
 /**
  * @brief sgemm computation : Y = op(A)*op(B) + C,
  * where op(X) is one of X or X**T
- * @param[in] transA CBLAS_TRANSPOSE
- * @param[in] transB CBLAS_TRANSPOSE
+ * @param[in] transA bool transpose
+ * @param[in] transB bool transpose
  * @param[in] A float * for Matrix A
  * @param[in] B float * for Matrix B
  * @param[in] C float * for Matrix C
@@ -74,10 +74,10 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1,
  * @param[in] ldc number of C's columns
  * @param[in] context RunLayerContext reference
 */
-void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const float *A,
-              const float *B, float *C, unsigned int M, unsigned int N,
-              unsigned int K, unsigned int lda, unsigned int ldb,
-              unsigned int ldc, RunLayerContext &context);
+void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
+              float *C, unsigned int M, unsigned int N, unsigned int K,
+              unsigned int lda, unsigned int ldb, unsigned int ldc,
+              RunLayerContext &context);
 
 /**
  * @brief addition : sum of all input vectors
@@ -140,8 +140,8 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
 /**
  * @brief fp16 sgemm computation : Y = op(A)*op(B) + C,
  * where op(X) is one of X or X**T
- * @param[in] transA CBLAS_TRANSPOSE
- * @param[in] transB CBLAS_TRANSPOSE
+ * @param[in] transA bool transpose
+ * @param[in] transB bool transpose
  * @param[in] A fp16 * for Matrix A
  * @param[in] B fp16 * for Matrix B
  * @param[in] C fp16 * for Matrix C
@@ -153,10 +153,10 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
  * @param[in] ldc number of C's columns
  * @param[in] context RunLayerContext reference
 */
-void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const __fp16 *A,
-              const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
-              unsigned int K, unsigned int lda, unsigned int ldb,
-              unsigned int ldc, RunLayerContext &context);
+void sgemm_cl(bool TransA, bool TransB, const __fp16 *A, const __fp16 *B,
+              __fp16 *C, unsigned int M, unsigned int N, unsigned int K,
+              unsigned int lda, unsigned int ldb, unsigned int ldc,
+              RunLayerContext &context);
 
 /**
  * @brief fp16 addition : sum of all input vectors
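Reviewer note, not part of the patch: a caller-side sketch of the bool-flag sgemm_cl declared above. The run_gemm_cl helper, its leading-dimension choices and the assumption that the declarations sit in namespace nntrainer (as in the rest of the tensor code) are illustrative; a real RunLayerContext must come from the invoking layer and is not constructed here.

// Sketch under the assumptions above; requires the nntrainer include path.
#include <blas_kernels.h>

void run_gemm_cl(nntrainer::RunLayerContext &context, const float *A,
                 const float *B, float *C, unsigned int M, unsigned int N,
                 unsigned int K) {
  // Row-major, no transposition: lda = K, ldb = N, ldc = N.
  // The two bools now select the NOTRANS/TRANSA/TRANSB kernels directly,
  // where the old API compared CBLAS_TRANSPOSE arguments against CblasTrans.
  nntrainer::sgemm_cl(false, false, A, B, C, M, N, K, K, N, N, context);
}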
diff --git a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
index 96c7ce9c90..b32e9d4a28 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
@@ -302,24 +302,24 @@ __fp16 dot_cl(const __fp16 *vecAdata, const __fp16 *vecXdata, unsigned int dim1,
   return cl_ret;
 }
 
-void sgemm_cl(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, const __fp16 *A,
-              const __fp16 *B, __fp16 *C, unsigned int M, unsigned int N,
-              unsigned int K, unsigned int lda, unsigned int ldb,
-              unsigned int ldc, RunLayerContext &context) {
+void sgemm_cl(bool TransA, bool TransB, const __fp16 *A, const __fp16 *B,
+              __fp16 *C, unsigned int M, unsigned int N, unsigned int K,
+              unsigned int lda, unsigned int ldb, unsigned int ldc,
+              RunLayerContext &context) {
 
   opencl::Kernel *kernel_sgemm_fp16 = nullptr;
   RunLayerContext::LayerKernel layerKernel;
   std::string sgemm_cl_kernel_fp16_;
 
-  if (TransA != CblasTrans && TransB != CblasTrans) {
+  if (!TransA && !TransB) {
     kernel_sgemm_fp16 = &kernel_sgemm_noTrans_fp16;
     layerKernel = context.LayerKernel::SGEMM_NOTRANS_FP16;
     sgemm_cl_kernel_fp16_ = sgemm_cl_noTrans_kernel_fp16_;
-  } else if (TransA == CblasTrans && TransB != CblasTrans) {
+  } else if (TransA && !TransB) {
     kernel_sgemm_fp16 = &kernel_sgemm_transA_fp16;
     layerKernel = context.LayerKernel::SGEMM_TRANSA_FP16;
     sgemm_cl_kernel_fp16_ = sgemm_cl_transA_kernel_fp16_;
-  } else if (TransA != CblasTrans && TransB == CblasTrans) {
+  } else if (!TransA && TransB) {
     kernel_sgemm_fp16 = &kernel_sgemm_transB_fp16;
     layerKernel = context.LayerKernel::SGEMM_TRANSB_FP16;
     sgemm_cl_kernel_fp16_ = sgemm_cl_transB_kernel_fp16_;
diff --git a/nntrainer/tensor/float_tensor.cpp b/nntrainer/tensor/float_tensor.cpp
index c925995fc9..755580b47d 100644
--- a/nntrainer/tensor/float_tensor.cpp
+++ b/nntrainer/tensor/float_tensor.cpp
@@ -493,8 +493,8 @@ void FloatTensor::sum_by_batch(Tensor &output) const {
   Tensor ones(1, 1, 1, feat_len, this->getFormat());
   ones.setValue(1.0);
-  sgemv(CblasRowMajor, CblasNoTrans, batch, feat_len, 1, data, feat_len,
-        ones.getData<float>(), 1, 0.0, out_data, 1);
+  sgemv((unsigned int)dim.getStorageOrder(), false, batch, feat_len, 1, data,
+        feat_len, ones.getData<float>(), 1, 0.0, out_data, 1);
 }
 
 Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
@@ -521,8 +521,8 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
     size_t batch = dim.batch();
     Tensor ones(1, 1, 1, batch, getTensorType());
     ones.setValue(alpha);
-    sgemv(CblasRowMajor, CblasTrans, batch, feat_len, 1, data, feat_len,
-          ones.getData<float>(), 1, beta, output.getData<float>(), 1);
+    sgemv((unsigned int)dim.getStorageOrder(), true, batch, feat_len, 1, data,
+          feat_len, ones.getData<float>(), 1, beta, output.getData<float>(), 1);
   } break;
   case 1: {
     CREATE_IF_EMPTY_DIMS(output, dim[0], 1, dim[2], dim[3], getTensorType());
@@ -531,8 +531,9 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
       unsigned int t_axis = dim[1];
       Tensor ones(1, 1, 1, t_axis, getTensorType());
       ones.setValue(alpha);
-      sgemv(CblasRowMajor, CblasNoTrans, feat_len, t_axis, 1, data, t_axis,
-            ones.getData<float>(), 1, beta, output.getData<float>(), 1);
+      sgemv((unsigned int)dim.getStorageOrder(), false, feat_len, t_axis, 1,
+            data, t_axis, ones.getData<float>(), 1, beta,
+            output.getData<float>(), 1);
     } else {
       unsigned int feat_len = dim[2] * dim[3];
       unsigned int t_axis = dim[1];
@@ -540,7 +541,7 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
       float *rdata = output.getData<float>();
       for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, feat_len, 1,
               &data[k * dim.getFeatureLen()], feat_len, ones.getData<float>(),
               1, beta, &rdata[k * feat_len], 1);
       }
@@ -555,7 +556,7 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
       float *rdata = output.getData<float>();
       for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, feat_len, 1,
               &data[k * dim.getFeatureLen()], feat_len, ones.getData<float>(),
               1, beta, &rdata[k * feat_len], 1);
       }
@@ -573,14 +574,15 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
           unsigned int ridx =
             k * output.getDim().getFeatureLen() + c * dim[3];
 
-          sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-                ones.getData<float>(), 1, beta, &rdata[ridx], 1);
+          sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, t_3, 1,
+                &data[idx], t_3, ones.getData<float>(), 1, beta, &rdata[ridx],
+                1);
         }
       }
     } else {
-      sgemv(CblasColMajor, CblasTrans, t_axis, output.getDim().getDataLen(),
-            1, data, t_axis, ones.getData<float>(), 1, beta,
-            output.getData<float>(), 1);
+      sgemv((unsigned int)dim.getStorageOrder(), true, t_axis,
+            output.getDim().getDataLen(), 1, data, t_axis,
+            ones.getData<float>(), 1, beta, output.getData<float>(), 1);
     }
   }
 } break;
@@ -597,8 +599,9 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
       for (unsigned int c = 0; c < dim[2]; ++c) {
         unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[1];
         unsigned int ridx = k * output.getDim().getFeatureLen() + c * dim[1];
-        sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-              ones.getData<float>(), 1, beta, &rdata[ridx], 1);
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, t_3, 1,
+              &data[idx], t_3, ones.getData<float>(), 1, beta, &rdata[ridx],
+              1);
       }
     }
   } else {
@@ -608,7 +611,7 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
     ones.setValue(alpha);
 
     if (dim.getStorageOrder() == TStorageOrder::ROW_MAJOR) {
-      sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n,
+      sgemv((unsigned int)dim.getStorageOrder(), false, m, n, 1, data, n,
             ones.getData<float>(), 1, beta, output.getData<float>(), 1);
     } else {
       float *rdata = output.getData<float>();
@@ -618,8 +621,9 @@ Tensor &FloatTensor::sum(unsigned int axis, Tensor &output, float alpha,
           unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[2];
           unsigned int ridx = k * dim[1] * dim[2] + c * dim[2];
 
-          sgemv(CblasColMajor, CblasNoTrans, dim[2], n, 1, &data[idx], dim[2],
-                ones.getData<float>(), 1, beta, &rdata[ridx], 1);
+          sgemv((unsigned int)dim.getStorageOrder(), false, dim[2], n, 1,
+                &data[idx], dim[2], ones.getData<float>(), 1, beta,
+                &rdata[ridx], 1);
         }
       }
     }
@@ -699,8 +703,6 @@ Tensor &FloatTensor::dot(Tensor const &input, Tensor &output, bool trans,
   const float *mdata = input.getData<float>();
   float *rdata = output.getData<float>();
   const float alpha = 1.0f;
-  enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
-  enum CBLAS_TRANSPOSE transB = trans_in ? CblasTrans : CblasNoTrans;
 
   /// shortcut handling in case of vector
   /// for vector, (1 * K) == (K * 1) in current memory layout...
@@ -714,21 +716,20 @@ Tensor &FloatTensor::dot(Tensor const &input, Tensor &output, bool trans,
   }
   /// case2: (M * K) X (K * 1)
   else if (N == 1) {
-    sgemv(CblasRowMajor, transA, first_three_flat, last_axis, alpha, data, lda,
-          mdata, 1, beta, rdata, 1);
+    sgemv((unsigned int)dim.getStorageOrder(), trans, first_three_flat,
+          last_axis, alpha, data, lda, mdata, 1, beta, rdata, 1);
   }
   /// case3: (1 * K) X (K * N) = 1 * N = R
   /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
   /// Effectively a translation of sgemv
   else if (M == 1) {
-    transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
-    sgemv(CblasRowMajor, transB, input_first_three_flat, input_last_axis, alpha,
-          mdata, ldb, data, 1, beta, rdata, 1);
+    sgemv((unsigned int)dim.getStorageOrder(), trans_in, input_first_three_flat,
+          input_last_axis, alpha, mdata, ldb, data, 1, beta, rdata, 1);
   }
   /// case others: use gemm
   else {
-    sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata, ldb,
-          beta, rdata, ldc);
+    sgemm((unsigned int)dim.getStorageOrder(), trans, trans_in, M, N, K, alpha,
+          data, lda, mdata, ldb, beta, rdata, ldc);
   }
 
   return output;
diff --git a/nntrainer/tensor/half_tensor.cpp b/nntrainer/tensor/half_tensor.cpp
index aa43dda048..14a9e7e0c2 100644
--- a/nntrainer/tensor/half_tensor.cpp
+++ b/nntrainer/tensor/half_tensor.cpp
@@ -478,8 +478,8 @@ void HalfTensor::sum_by_batch(Tensor &output) const {
   Tensor ones(1, 1, 1, feat_len, this->getTensorType());
   ones.setValue((_FP16)1.0);
-  sgemv(CblasRowMajor, CblasNoTrans, batch, feat_len, 1, data, feat_len,
-        ones.getData<_FP16>(), 1, 0.0, out_data, 1);
+  sgemv((unsigned int)dim.getStorageOrder(), false, batch, feat_len, 1, data,
+        feat_len, ones.getData<_FP16>(), 1, 0.0, out_data, 1);
 }
 
 Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
@@ -507,8 +507,8 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
     size_t batch = dim.batch();
     Tensor ones(1, 1, 1, batch, this->getTensorType());
     ones.setValue(alpha);
-    sgemv(CblasRowMajor, CblasTrans, batch, feat_len, 1, data, feat_len,
-          ones.getData<_FP16>(), 1, beta, output.getData<_FP16>(), 1);
+    sgemv((unsigned int)dim.getStorageOrder(), true, batch, feat_len, 1, data,
+          feat_len, ones.getData<_FP16>(), 1, beta, output.getData<_FP16>(), 1);
   } break;
   case 1: {
     CREATE_IF_EMPTY_DIMS(output, dim[0], 1, dim[2], dim[3], getTensorType());
@@ -517,8 +517,9 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       unsigned int t_axis = dim[1];
       Tensor ones(1, 1, 1, t_axis, this->getTensorType());
       ones.setValue(alpha);
-      sgemv(CblasRowMajor, CblasNoTrans, feat_len, t_axis, 1, data, t_axis,
-            ones.getData<_FP16>(), 1, beta, output.getData<_FP16>(), 1);
+      sgemv((unsigned int)dim.getStorageOrder(), false, feat_len, t_axis, 1,
+            data, t_axis, ones.getData<_FP16>(), 1, beta,
+            output.getData<_FP16>(), 1);
     } else {
       unsigned int feat_len = dim[2] * dim[3];
       unsigned int t_axis = dim[1];
@@ -526,7 +527,7 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
       _FP16 *rdata = output.getData<_FP16>();
       for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, feat_len, 1,
               &data[k * dim.getFeatureLen()], feat_len, ones.getData<_FP16>(),
               1, beta, &rdata[k * feat_len], 1);
       }
@@ -542,7 +543,7 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       ones.setValue(alpha);
       _FP16 *rdata = output.getData<_FP16>();
       for (unsigned int k = 0; k < dim[0]; ++k) {
-        sgemv(CblasRowMajor, CblasTrans, t_axis, feat_len, 1,
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, feat_len, 1,
               &data[k * dim.getFeatureLen()], feat_len, ones.getData<_FP16>(),
               1, beta, &rdata[k * feat_len], 1);
       }
@@ -556,8 +557,9 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       for (unsigned int c = 0; c < dim[1]; ++c) {
         unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[2];
         unsigned int ridx = k * output.getDim().getFeatureLen() + c * dim[3];
-        sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-              ones.getData<_FP16>(), 1, beta, &rdata[ridx], 1);
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, t_3, 1,
+              &data[idx], t_3, ones.getData<_FP16>(), 1, beta, &rdata[ridx],
+              1);
       }
     }
   }
@@ -574,8 +576,9 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
       for (unsigned int c = 0; c < dim[2]; ++c) {
         unsigned int idx = k * dim.getFeatureLen() + c * dim[3] * dim[1];
         unsigned int ridx = k * output.getDim().getFeatureLen() + c * dim[1];
-        sgemv(CblasRowMajor, CblasTrans, t_axis, t_3, 1, &data[idx], t_3,
-              ones.getData<_FP16>(), 1, beta, &rdata[ridx], 1);
+        sgemv((unsigned int)dim.getStorageOrder(), true, t_axis, t_3, 1,
+              &data[idx], t_3, ones.getData<_FP16>(), 1, beta, &rdata[ridx],
+              1);
       }
     }
   } else {
@@ -583,7 +586,7 @@ Tensor &HalfTensor::sum(unsigned int axis, Tensor &output, float alpha,
     unsigned int n = dim[3];
     Tensor ones(1, 1, 1, n, getTensorType());
     ones.setValue(alpha);
-    sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n,
+    sgemv((unsigned int)dim.getStorageOrder(), false, m, n, 1, data, n,
           ones.getData<_FP16>(), 1, beta, output.getData<_FP16>(), 1);
   }
 } break;
@@ -651,8 +654,6 @@ Tensor &HalfTensor::dot(Tensor const &input, Tensor &output, bool trans,
   const _FP16 *mdata = input.getData<_FP16>();
   _FP16 *rdata = output.getData<_FP16>();
   const float alpha = 1.0f;
-  enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
-  enum CBLAS_TRANSPOSE transB = trans_in ? CblasTrans : CblasNoTrans;
 
   /// shortcut handling in case of vector
   /// for vector, (1 * K) == (K * 1) in current memory layout...
@@ -666,21 +667,20 @@ Tensor &HalfTensor::dot(Tensor const &input, Tensor &output, bool trans,
   }
   /// case2: (M * K) X (K * 1)
   else if (N == 1) {
-    sgemv(CblasRowMajor, transA, first_three_flat, last_axis, alpha, data, lda,
-          mdata, 1, beta, rdata, 1);
+    sgemv((unsigned int)dim.getStorageOrder(), trans, first_three_flat,
+          last_axis, alpha, data, lda, mdata, 1, beta, rdata, 1);
   }
   /// case3: (1 * K) X (K * N) = 1 * N = R
   /// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
   /// Effectively a translation of sgemv
   else if (M == 1) {
-    transB = transB == CblasTrans ? CblasNoTrans : CblasTrans;
-    sgemv(CblasRowMajor, transB, input_first_three_flat, input_last_axis, alpha,
-          mdata, ldb, data, 1, beta, rdata, 1);
+    sgemv((unsigned int)dim.getStorageOrder(), trans_in, input_first_three_flat,
+          input_last_axis, alpha, mdata, ldb, data, 1, beta, rdata, 1);
   }
   /// case others: use sgemm
   else {
-    sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata, ldb,
-          beta, rdata, ldc);
+    sgemm((unsigned int)dim.getStorageOrder(), trans, trans_in, M, N, K, alpha,
+          data, lda, mdata, ldb, beta, rdata, ldc);
   }
 
   return output;
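Reviewer note, not part of the patch: every tensor-side call above passes (unsigned int)dim.getStorageOrder() straight into the new TStorageOrder parameter, so correctness rests on ROW_MAJOR converting to 0 and COL_MAJOR to a non-zero value, because blas_interface.cpp maps any non-zero value to CblasColMajor. A compile-time guard along these lines could pin that assumption; the static_asserts are a suggestion, not part of the PR, and use only the TStorageOrder enum already referenced by float_tensor.cpp.

// Suggested guard under the assumption above; place it in a translation unit
// that already sees the TStorageOrder enum (e.g. float_tensor.cpp).
static_assert(static_cast<unsigned int>(TStorageOrder::ROW_MAJOR) == 0,
              "row-major tensors must map to TStorageOrder 0 for sgemm/sgemv");
static_assert(static_cast<unsigned int>(TStorageOrder::COL_MAJOR) != 0,
              "column-major tensors must map to a non-zero TStorageOrder");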