Skip to content

Commit

Permalink
Implements overloads for binary element-wise functions between a broadcast scalar and an array
Browse files Browse the repository at this point in the history

Implements these new kernels for addition
  • Loading branch information
ndgrigorian committed Aug 27, 2024
1 parent dd5f585 commit c3f411b
Show file tree
Hide file tree
Showing 32 changed files with 1,349 additions and 54 deletions.
214 changes: 214 additions & 0 deletions dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,42 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
tmp);
}
}

/// Vector + scalar overload: adds the scalar `in2` to the SYCL vector `in1`
/// and returns the sum as a `sycl::vec<resT, vec_sz>`.
///
/// The sum is returned directly when its element type already equals `resT`;
/// otherwise it is converted with `vec_cast`.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
                                   const argT2 &in2) const
{
    auto sum = in1 + in2;
    // Element type the addition actually produced.
    using sumElemT = typename decltype(sum)::element_type;

    if constexpr (!std::is_same_v<resT, sumElemT>) {
        using dpctl::tensor::type_utils::vec_cast;

        return vec_cast<resT, sumElemT, vec_sz>(sum);
    }
    else {
        return sum;
    }
}

/// Scalar + vector overload: adds the scalar `in1` to the SYCL vector `in2`
/// and returns the sum as a `sycl::vec<resT, vec_sz>`.
///
/// The sum is returned directly when its element type already equals `resT`;
/// otherwise it is converted with `vec_cast`.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const argT1 &in1,
                                   const sycl::vec<argT2, vec_sz> &in2) const
{
    auto sum = in1 + in2;
    // Element type the addition actually produced.
    using sumElemT = typename decltype(sum)::element_type;

    if constexpr (!std::is_same_v<resT, sumElemT>) {
        using dpctl::tensor::type_utils::vec_cast;

        return vec_cast<resT, sumElemT, vec_sz>(sum);
    }
    else {
        return sum;
    }
}
};

template <typename argT1,
Expand Down Expand Up @@ -393,6 +429,126 @@ struct AddContigRowContigMatrixBroadcastFactory
}
};

/// Alias that plugs `AddFunctor<argT1, argT2, resT>` into
/// `elementwise_common::BinaryScalarContigArrayFunctor`, forwarding the
/// vectorization parameters unchanged (defaults: vec_sz = 4, n_vecs = 2,
/// enable_sg_loadstore = true).
template <typename argT1, typename argT2, typename resT,
          unsigned int vec_sz = 4, unsigned int n_vecs = 2,
          bool enable_sg_loadstore = true>
using AddScalarContigArrayFunctor =
    elementwise_common::BinaryScalarContigArrayFunctor<
        argT1, argT2, resT,
        AddFunctor<argT1, argT2, resT>,
        vec_sz, n_vecs, enable_sg_loadstore>;

/// Alias that plugs `AddFunctor<argT1, argT2, resT>` into
/// `elementwise_common::BinaryContigArrayScalarFunctor`, forwarding the
/// vectorization parameters unchanged (defaults: vec_sz = 4, n_vecs = 2,
/// enable_sg_loadstore = true).
template <typename argT1, typename argT2, typename resT,
          unsigned int vec_sz = 4, unsigned int n_vecs = 2,
          bool enable_sg_loadstore = true>
using AddContigArrayScalarFunctor =
    elementwise_common::BinaryContigArrayScalarFunctor<
        argT1, argT2, resT,
        AddFunctor<argT1, argT2, resT>,
        vec_sz, n_vecs, enable_sg_loadstore>;

// Declaration-only class template; it is passed to
// binary_scalar_contig_array_impl as a template argument by
// add_scalar_contig_array_impl and is never defined here.
// Parameters, in order: argT1, argT2, resT, vec_sz, n_vecs.
template <typename, typename, typename, unsigned int, unsigned int>
class add_scalar_contig_array_kernel;

/// Submits the scalar/contiguous-array addition by forwarding every argument
/// to `elementwise_common::binary_scalar_contig_array_impl`, instantiated
/// with `AddOutputType`, `AddScalarContigArrayFunctor` and
/// `add_scalar_contig_array_kernel`.
///
/// @param exec_q      Queue handed to the common implementation.
/// @param nelems      Element count handed to the common implementation.
/// @param arg1_p      Untyped pointer to the first operand
///                    (presumably the scalar -- TODO confirm against
///                    binary_scalar_contig_array_impl).
/// @param arg1_offset Offset associated with `arg1_p`.
/// @param arg2_p      Untyped pointer to the second operand
///                    (presumably the contiguous array -- TODO confirm).
/// @param arg2_offset Offset associated with `arg2_p`.
/// @param res_p       Untyped pointer to the result storage.
/// @param res_offset  Offset associated with `res_p`.
/// @param depends     Events forwarded as dependencies (default: none).
/// @return The event returned by the common implementation.
template <typename argTy1, typename argTy2>
sycl::event
add_scalar_contig_array_impl(sycl::queue &exec_q,
                             size_t nelems,
                             const char *arg1_p,
                             ssize_t arg1_offset,
                             const char *arg2_p,
                             ssize_t arg2_offset,
                             char *res_p,
                             ssize_t res_offset,
                             const std::vector<sycl::event> &depends = {})
{
    sycl::event add_ev = elementwise_common::binary_scalar_contig_array_impl<
        argTy1, argTy2, AddOutputType, AddScalarContigArrayFunctor,
        add_scalar_contig_array_kernel>(exec_q, nelems,
                                        arg1_p, arg1_offset,
                                        arg2_p, arg2_offset,
                                        res_p, res_offset,
                                        depends);
    return add_ev;
}

// Declaration-only class template; it is passed to
// binary_contig_array_scalar_impl as a template argument by
// add_contig_array_scalar_impl and is never defined here.
// Parameters, in order: argT1, argT2, resT, vec_sz, n_vecs.
template <typename, typename, typename, unsigned int, unsigned int>
class add_contig_array_scalar_kernel;

/// Submits the contiguous-array/scalar addition by forwarding every argument
/// to `elementwise_common::binary_contig_array_scalar_impl`, instantiated
/// with `AddOutputType`, `AddContigArrayScalarFunctor` and
/// `add_contig_array_scalar_kernel`.
///
/// @param exec_q      Queue handed to the common implementation.
/// @param nelems      Element count handed to the common implementation.
/// @param arg1_p      Untyped pointer to the first operand
///                    (presumably the contiguous array -- TODO confirm
///                    against binary_contig_array_scalar_impl).
/// @param arg1_offset Offset associated with `arg1_p`.
/// @param arg2_p      Untyped pointer to the second operand
///                    (presumably the scalar -- TODO confirm).
/// @param arg2_offset Offset associated with `arg2_p`.
/// @param res_p       Untyped pointer to the result storage.
/// @param res_offset  Offset associated with `res_p`.
/// @param depends     Events forwarded as dependencies (default: none).
/// @return The event returned by the common implementation.
template <typename argTy1, typename argTy2>
sycl::event
add_contig_array_scalar_impl(sycl::queue &exec_q,
                             size_t nelems,
                             const char *arg1_p,
                             ssize_t arg1_offset,
                             const char *arg2_p,
                             ssize_t arg2_offset,
                             char *res_p,
                             ssize_t res_offset,
                             const std::vector<sycl::event> &depends = {})
{
    sycl::event add_ev = elementwise_common::binary_contig_array_scalar_impl<
        argTy1, argTy2, AddOutputType, AddContigArrayScalarFunctor,
        add_contig_array_scalar_kernel>(exec_q, nelems,
                                        arg1_p, arg1_offset,
                                        arg2_p, arg2_offset,
                                        res_p, res_offset,
                                        depends);
    return add_ev;
}

template <typename fnT, typename T1, typename T2>
struct AddScalarContigArrayFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = add_scalar_contig_array_impl<T1, T2>;
return fn;
}
}
};

template <typename fnT, typename T1, typename T2>
struct AddContigArrayScalarFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = add_contig_array_scalar_impl<T1, T2>;
return fn;
}
}
};

template <typename argT, typename resT> struct AddInplaceFunctor
{

Expand All @@ -409,6 +565,12 @@ template <typename argT, typename resT> struct AddInplaceFunctor
{
res += in;
}

// In-place vector/scalar overload: accumulates the scalar `in` into the
// SYCL vector `res` using the vector's compound-assignment operator.
// `res` is modified; nothing is returned.
template <int vec_sz>
void operator()(sycl::vec<resT, vec_sz> &res, const argT &in)
{
res += in;
}
};

template <typename argT,
Expand Down Expand Up @@ -606,6 +768,58 @@ struct AddInplaceRowMatrixBroadcastFactory
}
};

/// Alias that plugs `AddInplaceFunctor<argT, resT>` into
/// `elementwise_common::BinaryInplaceScalarContigFunctor`, forwarding the
/// vectorization parameters unchanged (defaults: vec_sz = 4, n_vecs = 2,
/// enable_sg_loadstore = true).
template <typename argT, typename resT,
          unsigned int vec_sz = 4, unsigned int n_vecs = 2,
          bool enable_sg_loadstore = true>
using AddInplaceScalarContigFunctor =
    elementwise_common::BinaryInplaceScalarContigFunctor<
        argT, resT,
        AddInplaceFunctor<argT, resT>,
        vec_sz, n_vecs, enable_sg_loadstore>;

// Declaration-only class template; it is passed to
// binary_inplace_scalar_contig_impl as a template argument by
// add_inplace_scalar_contig_impl and is never defined here.
// Parameters, in order: argT, resT, vec_sz, n_vecs.
template <typename, typename, unsigned int, unsigned int>
class add_inplace_scalar_contig_kernel;

/// Submits the in-place scalar addition by forwarding every argument to
/// `elementwise_common::binary_inplace_scalar_contig_impl`, instantiated with
/// `AddInplaceScalarContigFunctor` and `add_inplace_scalar_contig_kernel`.
///
/// @param exec_q     Queue handed to the common implementation.
/// @param nelems     Element count handed to the common implementation.
/// @param arg_p      Untyped pointer to the right-hand operand
///                   (presumably the scalar -- TODO confirm against
///                   binary_inplace_scalar_contig_impl).
/// @param arg_offset Offset associated with `arg_p`.
/// @param res_p      Untyped pointer to the storage updated in place.
/// @param res_offset Offset associated with `res_p`.
/// @param depends    Events forwarded as dependencies (default: none).
/// @return The event returned by the common implementation.
template <typename argTy, typename resTy>
sycl::event
add_inplace_scalar_contig_impl(sycl::queue &exec_q,
                               size_t nelems,
                               const char *arg_p,
                               ssize_t arg_offset,
                               char *res_p,
                               ssize_t res_offset,
                               const std::vector<sycl::event> &depends = {})
{
    sycl::event add_ev =
        elementwise_common::binary_inplace_scalar_contig_impl<
            argTy, resTy, AddInplaceScalarContigFunctor,
            add_inplace_scalar_contig_kernel>(exec_q, nelems,
                                              arg_p, arg_offset,
                                              res_p, res_offset,
                                              depends);
    return add_ev;
}

template <typename fnT, typename T1, typename T2>
struct AddInplaceScalarContigFactory
{
fnT get()
{
if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
fnT fn = nullptr;
return fn;
}
else {
fnT fn = add_inplace_scalar_contig_impl<T1, T2>;
return fn;
}
}
};

} // namespace add
} // namespace kernels
} // namespace tensor
Expand Down
Loading

0 comments on commit c3f411b

Please sign in to comment.