Skip to content

Commit

Permalink
[bugfix] Fix sgemv_cl function call from blas_kernel_interface
Browse files Browse the repository at this point in the history
Fixed sgemv_cl function call. Failing unittest after recent changes.

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
  • Loading branch information
s-debadri committed Oct 14, 2024
1 parent f222ecf commit 6c9fcf3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nntrainer/tensor/cl_operations/blas_kernel_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result, bool trans,
/// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
/// Effectively a translation of sgemv
else if (M == 1) {
trans_m ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb)
: sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb);
trans_m ? sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb)
: sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb);
}
/// case others: use gemm
else {
Expand Down Expand Up @@ -170,8 +170,8 @@ void dotCl(Tensor const &input, Tensor const &m, Tensor &result, bool trans,
/// = R^T = (K * N) ^T * (1 * K) ^T = (N * K) * (K * 1) = (N * K) * (1 * K)
/// Effectively a translation of sgemv
else if (M == 1) {
trans_m ? sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb)
: sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb);
trans_m ? sgemv_cl(mdata, data, rdata, mdim1, mdim2, ldb)
: sgemv_cl(mdata, data, rdata, mdim2, mdim1, ldb);
}
/// case others: use sgemm
else {
Expand Down

0 comments on commit 6c9fcf3

Please sign in to comment.