diff --git a/src/common/dnnl_traits.hpp b/src/common/dnnl_traits.hpp index bcebf6fcb01..d4e70f960a6 100644 --- a/src/common/dnnl_traits.hpp +++ b/src/common/dnnl_traits.hpp @@ -174,6 +174,7 @@ PKIND_TRAITS_INST(binary); PKIND_TRAITS_INST(matmul); PKIND_TRAITS_INST(resampling); PKIND_TRAITS_INST(reduction); +PKIND_TRAITS_INST(sum); PKIND_TRAITS_INST(sdpa); #undef PKIND_TRAITS_INST diff --git a/src/gpu/generic/sycl/ref_sum.cpp b/src/gpu/generic/sycl/ref_sum.cpp new file mode 100644 index 00000000000..de884246fe1 --- /dev/null +++ b/src/gpu/generic/sycl/ref_sum.cpp @@ -0,0 +1,114 @@ +/******************************************************************************* +* Copyright 2022-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "gpu/generic/sycl/ref_sum.hpp" +#include "gpu/generic/sycl/sum_kernels.hpp" +#include "gpu/generic/sycl/sycl_gpu_primitive.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +status_t ref_sum_t::pd_t::init_conf() { + conf_ = sycl_sum_conf_t(); + conf_.n = n_inputs(); + + for (auto i = 0; i < conf_.n; ++i) { + conf_.src_md[i] = xpu::sycl::md_t(src_md(i)); + conf_.src_scales[i] = scales()[i]; + } + conf_.dst_md = xpu::sycl::md_t(dst_md()); + + // XXX: should probably be tuned. + conf_.block_size = 16; + conf_.wg_size = 32; + conf_.wk_size = memory_desc_wrapper(dst_md()).nelems(); + return status::success; +} + +status_t ref_sum_t::init(engine_t *engine) { + const auto kid = ::sycl::get_kernel_id(); + CHECK(create_kernel(engine, kid, &kernel_)); + + return status::success; +} + +status_t ref_sum_t::execute(const exec_ctx_t &ctx) const { + using namespace memory_tracking::names; + + parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) { + auto src0_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 0); + auto src1_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 1); + auto src2_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 2); + auto src3_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 3); + auto src4_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 4); + auto src5_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 5); + auto src6_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 6); + auto src7_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 7); + auto src8_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 8); + auto src9_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 9); + auto src10_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 10); + auto src11_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 11); + auto src12_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 12); + auto src13_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 13); + auto src14_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 14); + auto src15_mem_arg + = CTX_IN_SYCL_KERNEL_MEMORY(DNNL_ARG_MULTIPLE_SRC + 15); + + auto dst_mem_arg = CTX_OUT_SYCL_KERNEL_MEMORY(DNNL_ARG_DST); + + sum_kernel_vec_t sum_kernel(pd()->conf_, src0_mem_arg, src1_mem_arg, + src2_mem_arg, src3_mem_arg, src4_mem_arg, src5_mem_arg, + src6_mem_arg, src7_mem_arg, src8_mem_arg, src9_mem_arg, + src10_mem_arg, src11_mem_arg, src12_mem_arg, src13_mem_arg, + src14_mem_arg, src15_mem_arg, dst_mem_arg); + + const int block_size = pd()->conf_.block_size; + const int wg_size = pd()->conf_.wg_size; + + const int t_work = pd()->conf_.wk_size; + const int wg_work = wg_size * block_size; + const int wg_cnt = utils::div_up(t_work, wg_work); + + cgh.parallel_for( + ::sycl::nd_range<1>(wg_cnt * wg_size, wg_size), sum_kernel); + }); + + return status::success; +} + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl diff --git a/src/gpu/generic/sycl/ref_sum.hpp b/src/gpu/generic/sycl/ref_sum.hpp new file mode 100644 index 00000000000..68fb3fab4fc --- /dev/null +++ b/src/gpu/generic/sycl/ref_sum.hpp @@ -0,0 +1,97 @@ +/******************************************************************************* +* Copyright 2022-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GPU_SYCL_REF_SUM_HPP +#define GPU_SYCL_REF_SUM_HPP + +#include "common/primitive.hpp" +#include "common/stream.hpp" +#include "gpu/generic/sycl/sycl_gpu_primitive.hpp" +#include "gpu/generic/sycl/sycl_io_helper.hpp" +#include "gpu/generic/sycl/sycl_post_ops.hpp" +#include "gpu/generic/sycl/sycl_primitive_conf.hpp" +#include "gpu/generic/sycl/sycl_q10n.hpp" +#include "gpu/gpu_sum_pd.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +struct ref_sum_t : public gpu::generic::sycl::primitive_t { + using gpu::generic::sycl::primitive_t::primitive_t; + + struct pd_t : public gpu_sum_pd_t { + using gpu_sum_pd_t::gpu_sum_pd_t; + + DECLARE_SUM_PD_T("dpcpp:ref:any", ref_sum_t); + + status_t init(impl::engine_t *engine) { + using namespace data_type; + using namespace format_tag; + + const memory_desc_wrapper dst_d(dst_md()); + if (!utils::one_of(dst_d.data_type(), f32, bf16, f16, s8, u8)) + return status::unimplemented; + // Block formats are not yet supported + // Dimensions can not be > 6 + if (!dst_d.is_plain() || dst_d.ndims() > xpu::sycl::md_t::max_dims) + return status::unimplemented; + + const int n = n_inputs(); + for (auto i = 0; i < n; ++i) { + const memory_desc_wrapper src_d(src_md(i)); + if (!utils::one_of(src_d.data_type(), f32, bf16, f16, s8, u8)) + return status::unimplemented; + // Block formats are not yet supported + // Dimensions can not be > 6 + if (!src_d.is_plain() || src_d.ndims() > xpu::sycl::md_t::max_dims) + return status::unimplemented; + } + + const bool ok = set_default_params() == status::success + && n <= DNNL_REF_SUM_MAX_NUM_TENSORS; + if (!ok) return status::unimplemented; + + return init_conf(); + } + + sycl_sum_conf_t conf_; + + private: + status_t init_conf(); + + inline bool equal(float in_value, float in_compare_to) { + return std::fabs(in_value - in_compare_to) <= FLT_EPSILON; + } + }; + + status_t init(impl::engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + kernel_t kernel_; +}; + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/gpu/generic/sycl/ref_sum_many_inputs.cpp b/src/gpu/generic/sycl/ref_sum_many_inputs.cpp new file mode 100644 index 00000000000..0ef2e41cf5c --- /dev/null +++ b/src/gpu/generic/sycl/ref_sum_many_inputs.cpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2022-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "gpu/generic/sycl/ref_sum_many_inputs.hpp" +#include "common/primitive_desc_iface.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +status_t ref_sum_many_inputs_t::pd_t::init_conf() { + conf_ = sycl_sum_conf_t(); + conf_.n = n_inputs(); + + return status::success; +} + +status_t ref_sum_many_inputs_t::init(engine_t *engine) { + const size_t n = pd()->base_pds_.size(); + base_prims_.resize(n); + for (size_t i = 0; i < n; ++i) { + CHECK(pd()->base_pds_[i]->impl()->create_primitive( + base_prims_[i], engine, cache_blob())); + } + + return status::success; +} + +status_t ref_sum_many_inputs_t::execute(const exec_ctx_t &ctx) const { + memory_arg_t dst_mem_arg = {ctx.args().at(DNNL_ARG_DST).mem, false}; + memory_arg_t dst_read_mem_arg = {ctx.args().at(DNNL_ARG_DST).mem, true}; + + int n_remaining = pd()->conf_.n; + int in_arg_offset = 0; + int i = 0; + + while (n_remaining > 0) { + bool pass_in_dst = i > 0; + int max_n_child_inputs = DNNL_REF_SUM_MAX_NUM_TENSORS - pass_in_dst; + int args_handled = std::min(n_remaining, max_n_child_inputs); + exec_args_t r_args; + r_args[DNNL_ARG_DST] = dst_mem_arg; + if (pass_in_dst) { + r_args[DNNL_ARG_MULTIPLE_SRC + 0] = dst_read_mem_arg; + } + for (int j = 0; j < args_handled; j++) { + r_args[DNNL_ARG_MULTIPLE_SRC + j + pass_in_dst] + = ctx.args().at(DNNL_ARG_MULTIPLE_SRC + j + in_arg_offset); + } + n_remaining -= args_handled; + in_arg_offset += args_handled; + i++; + + exec_ctx_t r_ctx(ctx, std::move(r_args)); + CHECK(base_prims_[i]->execute(r_ctx)); + } + return status::success; +} + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl diff --git a/src/gpu/generic/sycl/ref_sum_many_inputs.hpp b/src/gpu/generic/sycl/ref_sum_many_inputs.hpp new file mode 100644 index 00000000000..72530fc9cfd --- /dev/null +++ b/src/gpu/generic/sycl/ref_sum_many_inputs.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2022-2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GPU_SYCL_REF_SUM_MANY_INPUTS_HPP +#define GPU_SYCL_REF_SUM_MANY_INPUTS_HPP + +#include "common/primitive.hpp" +#include "gpu/generic/sycl/sycl_gpu_primitive.hpp" +#include "gpu/generic/sycl/sycl_io_helper.hpp" +#include "gpu/generic/sycl/sycl_post_ops.hpp" +#include "gpu/generic/sycl/sycl_primitive_conf.hpp" +#include "gpu/generic/sycl/sycl_q10n.hpp" +#include "gpu/gpu_sum_pd.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +struct ref_sum_many_inputs_t : public gpu::generic::sycl::primitive_t { + using gpu::generic::sycl::primitive_t::primitive_t; + + struct pd_t : public gpu_sum_pd_t { + using gpu_sum_pd_t::gpu_sum_pd_t; + + DECLARE_SUM_PD_t("dpcpp:ref:sum_many_inputs", ref_sum_many_inputs_t); + + status_t init(impl::engine_t *engine) { + using namespace data_type; + + const memory_desc_wrapper dst_d(dst_md()); + + const int n = n_inputs(); + const bool ok = set_default_params() == status::success + && attr()->has_default_values() + && n > DNNL_REF_SUM_MAX_NUM_TENSORS; // prevent inf recursion + if (!ok) return status::unimplemented; + + // the first kernel handles up to 16 inputs and remaining ones up to 15 + const int n_kernels = n == 1 + ? 1 + : utils::div_up(n - 1, DNNL_REF_SUM_MAX_NUM_TENSORS - 1); + base_pds_.resize(n_kernels); + int in_arg_offset = 0; + int n_remaining = n; + for (auto i = 0; i < n_kernels; ++i) { + bool pass_in_dst = i > 0; + int max_n_child_inputs + = DNNL_REF_SUM_MAX_NUM_TENSORS - pass_in_dst; + int n_child_inputs = std::min(n_remaining, max_n_child_inputs); + const memory_desc_t *src[DNNL_REF_SUM_MAX_NUM_TENSORS]; + if (pass_in_dst) { src[0] = dst_md(); } + for (int j = 0; j < n_child_inputs; j++) { + src[j + pass_in_dst] = src_md(j + in_arg_offset); + } + in_arg_offset += n_child_inputs; + n_remaining -= n_child_inputs; + + primitive_attr_t r_attr; + CHECK(dnnl_sum_primitive_desc_create(&base_pds_[i], engine, + dst_md(), n_child_inputs + pass_in_dst, scales(), src, + &r_attr)); + } + + return init_conf(); + } + + sycl_sum_conf_t conf_; + std::vector base_pds_; + + private: + status_t init_conf(); + }; + + status_t init(impl::engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::vector> base_prims_; +}; + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/gpu/generic/sycl/sum_kernels.hpp b/src/gpu/generic/sycl/sum_kernels.hpp new file mode 100644 index 00000000000..61e1c862533 --- /dev/null +++ b/src/gpu/generic/sycl/sum_kernels.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2022-2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GPU_SYCL_SUM_KERNELS_HPP +#define GPU_SYCL_SUM_KERNELS_HPP + +#include "gpu/generic/sycl/sycl_io_helper.hpp" +#include "gpu/generic/sycl/sycl_post_ops.hpp" +#include "gpu/generic/sycl/sycl_primitive_conf.hpp" +#include "gpu/generic/sycl/sycl_q10n.hpp" + +namespace dnnl { +namespace impl { +namespace gpu { +namespace generic { +namespace sycl { + +#define DNNL_ARG_SRC_4 5 +#define DNNL_ARG_SRC_5 6 +#define DNNL_ARG_SRC_6 7 +#define DNNL_ARG_SRC_7 8 +#define DNNL_ARG_SRC_8 9 +#define DNNL_ARG_SRC_9 10 +#define DNNL_ARG_SRC_10 11 +#define DNNL_ARG_SRC_11 12 +#define DNNL_ARG_SRC_12 13 +#define DNNL_ARG_SRC_13 14 +#define DNNL_ARG_SRC_14 15 +#define DNNL_ARG_SRC_15 16 + +struct sum_kernel_vec_t { + static constexpr int vec_len = 8; + static constexpr int max_supported_ndims = 6; + + sum_kernel_vec_t(const sycl_sum_conf_t &conf, + xpu::sycl::in_memory_arg_t &src0, xpu::sycl::in_memory_arg_t &src1, + xpu::sycl::in_memory_arg_t &src2, xpu::sycl::in_memory_arg_t &src3, + xpu::sycl::in_memory_arg_t &src4, xpu::sycl::in_memory_arg_t &src5, + xpu::sycl::in_memory_arg_t &src6, xpu::sycl::in_memory_arg_t &src7, + xpu::sycl::in_memory_arg_t &src8, xpu::sycl::in_memory_arg_t &src9, + xpu::sycl::in_memory_arg_t &src10, + xpu::sycl::in_memory_arg_t &src11, + xpu::sycl::in_memory_arg_t &src12, + xpu::sycl::in_memory_arg_t &src13, + xpu::sycl::in_memory_arg_t &src14, + xpu::sycl::in_memory_arg_t &src15, xpu::sycl::out_memory_arg_t &dst) + : conf_(conf) + , src0_(src0) + , src1_(src1) + , src2_(src2) + , src3_(src3) + , src4_(src4) + , src5_(src5) + , src6_(src6) + , src7_(src7) + , src8_(src8) + , src9_(src9) + , src10_(src10) + , src11_(src11) + , src12_(src12) + , src13_(src13) + , src14_(src14) + , src15_(src15) + , dst_(dst) {} + + void operator()(::sycl::nd_item<1> item) const { + auto sg = item.get_sub_group(); + size_t wg_offset_t = item.get_group(0) * conf_.wg_size; + size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0]; + size_t wi_offset_t = sg.get_local_id(); + size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t; + + size_t base_idx = offset_t * conf_.block_size; + + dims_t dims, off, strides; + for (int i = 0; i < max_supported_ndims; i++) { + dims[i] = (i < conf_.dst_md.ndims()) ? conf_.dst_md.dims()[i] : 1; + strides[i] = (i < conf_.dst_md.ndims()) ? conf_.dst_md.strides()[i] + : INT_MAX; + } + + for (int i = 0; i < conf_.block_size; i++) { + int idx = base_idx + i; + if (idx < conf_.wk_size) { + for (int i = 0; i < max_supported_ndims; i++) { + off[i] = idx / strides[i] % dims[i]; + } + auto result = load_float_val(src0_ptr(), conf_.src_md[0], off); + +#define ONEDNN_SYCL_SUM_ADD_ARG(ARG_N) \ + if (conf_.n > ARG_N) \ + result += conf_.src_scales[ARG_N] \ + * load_float_val( \ + src##ARG_N##_ptr(), conf_.src_md[ARG_N], off); + + ONEDNN_SYCL_SUM_ADD_ARG(1) + ONEDNN_SYCL_SUM_ADD_ARG(2) + ONEDNN_SYCL_SUM_ADD_ARG(3) + ONEDNN_SYCL_SUM_ADD_ARG(4) + ONEDNN_SYCL_SUM_ADD_ARG(5) + ONEDNN_SYCL_SUM_ADD_ARG(6) + ONEDNN_SYCL_SUM_ADD_ARG(7) + ONEDNN_SYCL_SUM_ADD_ARG(8) + ONEDNN_SYCL_SUM_ADD_ARG(9) + ONEDNN_SYCL_SUM_ADD_ARG(11) + ONEDNN_SYCL_SUM_ADD_ARG(11) + ONEDNN_SYCL_SUM_ADD_ARG(12) + ONEDNN_SYCL_SUM_ADD_ARG(13) + ONEDNN_SYCL_SUM_ADD_ARG(14) + ONEDNN_SYCL_SUM_ADD_ARG(15) +#undef ONEDNN_SYCL_SUM_ADD_ARG + + store_float_value( + conf_.dst_md.data_type(), result, dst_ptr(), idx); + } + } + } + +private: + float load_float_val( + const void *ptr, xpu::sycl::md_t md, dims_t off) const { + return load_float_value(md.data_type(), ptr, md.off_v(off)); + } + + void *src0_ptr() const { return src0_.get_pointer(); } + void *src1_ptr() const { return src1_.get_pointer(); } + void *src2_ptr() const { return src2_.get_pointer(); } + void *src3_ptr() const { return src3_.get_pointer(); } + void *src4_ptr() const { return src4_.get_pointer(); } + void *src5_ptr() const { return src5_.get_pointer(); } + void *src6_ptr() const { return src6_.get_pointer(); } + void *src7_ptr() const { return src7_.get_pointer(); } + void *src8_ptr() const { return src8_.get_pointer(); } + void *src9_ptr() const { return src9_.get_pointer(); } + void *src10_ptr() const { return src10_.get_pointer(); } + void *src11_ptr() const { return src11_.get_pointer(); } + void *src12_ptr() const { return src12_.get_pointer(); } + void *src13_ptr() const { return src13_.get_pointer(); } + void *src14_ptr() const { return src14_.get_pointer(); } + void *src15_ptr() const { return src15_.get_pointer(); } + void *dst_ptr() const { return dst_.get_pointer(); } + + sycl_sum_conf_t conf_; + + xpu::sycl::in_memory_arg_t src0_; + xpu::sycl::in_memory_arg_t src1_; + xpu::sycl::in_memory_arg_t src2_; + xpu::sycl::in_memory_arg_t src3_; + xpu::sycl::in_memory_arg_t src4_; + xpu::sycl::in_memory_arg_t src5_; + xpu::sycl::in_memory_arg_t src6_; + xpu::sycl::in_memory_arg_t src7_; + xpu::sycl::in_memory_arg_t src8_; + xpu::sycl::in_memory_arg_t src9_; + xpu::sycl::in_memory_arg_t src10_; + xpu::sycl::in_memory_arg_t src11_; + xpu::sycl::in_memory_arg_t src12_; + xpu::sycl::in_memory_arg_t src13_; + xpu::sycl::in_memory_arg_t src14_; + xpu::sycl::in_memory_arg_t src15_; + xpu::sycl::out_memory_arg_t dst_; +}; + +} // namespace sycl +} // namespace generic +} // namespace gpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/gpu/generic/sycl/sycl_primitive_conf.hpp b/src/gpu/generic/sycl/sycl_primitive_conf.hpp index 6cd97d1cdc7..5b4c2e0c8ce 100644 --- a/src/gpu/generic/sycl/sycl_primitive_conf.hpp +++ b/src/gpu/generic/sycl/sycl_primitive_conf.hpp @@ -378,6 +378,19 @@ struct sycl_pooling_fwd_conf_t : public sycl_pooling_base_conf_t { sycl_post_ops_t post_ops; }; +#define DNNL_REF_SUM_MAX_NUM_TENSORS 16 + +struct sycl_sum_conf_t { + xpu::sycl::md_t src_md[DNNL_REF_SUM_MAX_NUM_TENSORS]; + xpu::sycl::md_t dst_md; + float src_scales[DNNL_REF_SUM_MAX_NUM_TENSORS]; + int n; + int block_size; + int wg_size; + int wk_size; + +}; + struct sycl_pooling_bwd_conf_t : public sycl_pooling_base_conf_t { xpu::sycl::md_t diff_src_md; xpu::sycl::md_t diff_dst_md; @@ -392,6 +405,7 @@ CHECK_SYCL_KERNEL_ARG_TYPE(sycl_softmax_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_layer_normalization_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_eltwise_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_lrn_conf_t); +CHECK_SYCL_KERNEL_ARG_TYPE(sycl_sum_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_base_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_fwd_conf_t); CHECK_SYCL_KERNEL_ARG_TYPE(sycl_pooling_bwd_conf_t); diff --git a/src/gpu/gpu_sum_list.cpp b/src/gpu/gpu_sum_list.cpp index acf21625e7c..1703a6b3909 100644 --- a/src/gpu/gpu_sum_list.cpp +++ b/src/gpu/gpu_sum_list.cpp @@ -30,6 +30,8 @@ #endif #if DNNL_GPU_VENDOR == DNNL_VENDOR_NVIDIA +#include "gpu/generic/sycl/ref_sum.hpp" +#include "gpu/generic/sycl/ref_sum_many_inputs.hpp" #include "gpu/nvidia/cudnn_sum.hpp" #endif @@ -49,6 +51,8 @@ constexpr impl_list_item_t impl_list[] = REG_SUM_P({ GPU_SUM_INSTANCE_INTEL(intel::ocl::many_inputs_sum_t) GPU_SUM_INSTANCE_INTEL(intel::ocl::simple_sum_t) GPU_SUM_INSTANCE_NVIDIA(nvidia::cudnn_ref_sum_t) + GPU_SUM_INSTANCE_GENERIC_SYCL(generic::sycl::ref_sum_t) + GPU_SUM_INSTANCE_GENERIC_SYCL(generic::sycl::ref_sum_many_inputs_t) GPU_SUM_INSTANCE_GENERIC(generic::ref_sum_t) nullptr, });