From 537896a42e7ccfef19c089fee747dd6389791cdd Mon Sep 17 00:00:00 2001 From: aamijar Date: Wed, 12 Jun 2024 17:21:35 +0000 Subject: [PATCH 01/11] dice-distance-dense-inputs --- cpp/CMakeLists.txt | 2 + cpp/include/raft/distance/detail/distance.cuh | 74 +++++++++++++++- .../distance/detail/distance_ops/all_ops.cuh | 3 +- .../distance/detail/distance_ops/dice.cuh | 85 +++++++++++++++++++ .../detail/pairwise_matrix/dispatch-ext.cuh | 6 +- cpp/include/raft/distance/distance-ext.cuh | 48 +++++++++++ cpp/include/raft/distance/distance-inl.cuh | 3 + .../dispatch_dice_double_double_double_int.cu | 51 +++++++++++ .../dispatch_dice_float_float_float_int.cu | 51 +++++++++++ cpp/src/distance/distance.cu | 50 ++++++++++- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_dice.cu | 48 +++++++++++ cpp/test/distance/distance_base.cuh | 35 +++++++- 13 files changed, 449 insertions(+), 8 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/dice.cuh create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu create mode 100644 cpp/test/distance/dist_dice.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 39472cae67..fe9132b223 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -302,6 +302,8 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index a5c8c0ef4b..8e1059df9b 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -52,6 +53,7 @@ using distance_tag = std::integral_constant; * They are implemented below. The documentation of this function serves as * documentation for all functions. The following overloads are defined: * + * - DistanceType::DiceExpanded: * - DistanceType::Canberra: * - DistanceType::CorrelationExpanded: * - DistanceType::CosineExpanded: @@ -192,6 +194,70 @@ void distance_impl(raft::resources const& handle, corr_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + bool is_row_major, + DataT) // unused +{ + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), + "OutT can be uint8_t, float, double," + "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); + + ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + DataT* x_norm = workspace; + DataT* y_norm = workspace; + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. + if (x == y && is_row_major) { + raft::linalg::rowNorm( + x_norm, x, k, std::max(m, n), raft::linalg::L1Norm, is_row_major, stream, {}); + } else { + y_norm += m; + raft::linalg::reduce(x_norm, + x, + k, + m, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + raft::linalg::reduce(y_norm, + y, + k, + n, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + } + + ops::dice_distance_op distance_op{}; + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + template void distance_impl(raft::resources const& handle, distance_tag distance_type, @@ -794,9 +860,11 @@ template size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) { - size_t worksize = 0; - constexpr bool is_allocated = (distanceType <= raft::distance::DistanceType::CosineExpanded) || - (distanceType == raft::distance::DistanceType::CorrelationExpanded); + size_t worksize = 0; + constexpr bool is_allocated = + (distanceType <= raft::distance::DistanceType::CosineExpanded) || + (distanceType == raft::distance::DistanceType::CorrelationExpanded) || + (distanceType == raft::distance::DistanceType::DiceExpanded); constexpr int numOfBuffers = (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1; diff --git a/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh index 3e8f4e86fb..84eb3c705b 100644 --- a/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/distance/detail/distance_ops/dice.cuh b/cpp/include/raft/distance/detail/distance_ops/dice.cuh new file mode 100644 index 0000000000..b8651c5592 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/dice.cuh @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // DI + +namespace raft::distance::detail::ops { + +// Epilogue operator for CUTLASS based kernel +template +struct dice_cutlass_op { + __device__ dice_cutlass_op() noexcept {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + return static_cast(1.0) - static_cast(2 * accVal / (aNorm + bNorm)); + } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + +/** + * @brief the expanded dice distance matrix calculation + * + * It computes the following equation: + * + * d(x, y) = 1 - 2*(x ⋅ y) / ( ||x||_1 + ||y||_1 ) + */ +template +struct dice_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + // Load norms of input data + static constexpr bool use_norms = true; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = 1.0 - (2 * acc[i][j] / (regxn[i] + regyn[j])); + } + } + } + + constexpr dice_cutlass_op get_cutlass_op() const + { + return dice_cutlass_op(); + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh index e1dc6f9b37..bced721ec8 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -120,6 +120,10 @@ instantiate_raft_distance_detail_pairwise_matrix_dispatch( raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); instantiate_raft_distance_detail_pairwise_matrix_dispatch( raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::dice_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int); instantiate_raft_distance_detail_pairwise_matrix_dispatch( raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); instantiate_raft_distance_detail_pairwise_matrix_dispatch( diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh index a634e8c995..2d41e029fe 100644 --- a/cpp/include/raft/distance/distance-ext.cuh +++ b/cpp/include/raft/distance/distance-ext.cuh @@ -204,6 +204,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, raft::identity_op, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); instantiate_raft_distance_distance( @@ -286,6 +290,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_distance( @@ -362,6 +370,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_distance( @@ -429,6 +441,10 @@ instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_getWorkspaceSize( @@ -547,6 +563,22 @@ instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineE double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, float, float, @@ -822,6 +854,22 @@ instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, float, float, diff --git a/cpp/include/raft/distance/distance-inl.cuh b/cpp/include/raft/distance/distance-inl.cuh index 647c5b2908..13c9d57efd 100644 --- a/cpp/include/raft/distance/distance-inl.cuh +++ b/cpp/include/raft/distance/distance-inl.cuh @@ -306,6 +306,9 @@ void pairwise_distance(raft::resources const& handle, case DistanceType::RusselRaoExpanded: dispatch(std::integral_constant{}); break; + case DistanceType::DiceExpanded: + dispatch(std::integral_constant{}); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu new file mode 100644 index 0000000000..a259f8b3b0 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu new file mode 100644 index 0000000000..e89f8b422c --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::dice_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu index 8c94608311..8fe0bf2007 100644 --- a/cpp/src/distance/distance.cu +++ b/cpp/src/distance/distance.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,6 +72,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, raft::identity_op, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); instantiate_raft_distance_distance( @@ -154,6 +158,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_distance( @@ -230,6 +238,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_distance( @@ -297,6 +309,10 @@ instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_getWorkspaceSize( @@ -415,6 +431,22 @@ instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineE double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, float, float, @@ -690,6 +722,22 @@ instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, float, float, diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ff0518a4d0..ed923fb1db 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -156,6 +156,7 @@ if(BUILD_TESTS) distance/dist_canberra.cu distance/dist_correlation.cu distance/dist_cos.cu + distance/dist_dice.cu distance/dist_hamming.cu distance/dist_hellinger.cu distance/dist_inner_product.cu diff --git a/cpp/test/distance/dist_dice.cu b/cpp/test/distance/dist_dice.cu new file mode 100644 index 0000000000..6606dd6439 --- /dev/null +++ b/cpp/test/distance/dist_dice.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceExpDice : public DistanceTest { +}; + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceExpDice DistanceExpDiceD; +TEST_P(DistanceExpDiceD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceD, ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 2854a8f3df..a52baa02fa 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -14,8 +14,6 @@ * limitations under the License. */ -#include "../test_utils.cuh" - #include // common::nvtx::range #include // make_device_matrix_view #include // raft::sqrt @@ -29,6 +27,8 @@ #include +#include + namespace raft { namespace distance { @@ -96,6 +96,34 @@ RAFT_KERNEL naiveL1_Linf_CanberraDistanceKernel(DataType* dist, dist[outidx] = acc; } +template +RAFT_KERNEL naiveDiceDistanceKernel( + DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) +{ + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { return; } + + DataType acc_a = DataType(0); + DataType acc_b = DataType(0); + DataType acc_ab = DataType(0); + + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc_a += a; + acc_b += b; + acc_ab += a * b; + } + + int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + + // Use 1.0 - (dice similarity) to calc the distance + dist[outidx] = (DataType)1.0 - (2 * acc_ab / ((acc_a) + (acc_b))); +} + template RAFT_KERNEL naiveCosineDistanceKernel( DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) @@ -391,6 +419,9 @@ void naiveDistance(DataType* dist, naiveCorrelationDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::DiceExpanded: + naiveDiceDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + break; default: FAIL() << "should be here\n"; } RAFT_CUDA_TRY(cudaPeekAtLastError()); From 92afe3227889067fb7dbef8f27bcf5707274bbd7 Mon Sep 17 00:00:00 2001 From: aamijar Date: Thu, 13 Jun 2024 01:38:21 +0000 Subject: [PATCH 02/11] update test --- cpp/include/raft/distance/detail/distance.cuh | 1 - cpp/test/distance/dist_dice.cu | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 8e1059df9b..08815c7ab5 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/test/distance/dist_dice.cu b/cpp/test/distance/dist_dice.cu index 6606dd6439..617a165e54 100644 --- a/cpp/test/distance/dist_dice.cu +++ b/cpp/test/distance/dist_dice.cu @@ -24,7 +24,7 @@ template class DistanceExpDice : public DistanceTest { }; -const std::vector> inputsd = { +const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, {0.001, 32, 1024, 1024, true, 1234ULL}, @@ -34,7 +34,7 @@ const std::vector> inputsd = { {0.001f, 32, 1024, 1024, false, 1234ULL}, {0.003f, 1024, 1024, 1024, false, 1234ULL}, }; -typedef DistanceExpDice DistanceExpDiceD; +typedef DistanceExpDice DistanceExpDiceD; TEST_P(DistanceExpDiceD, Result) { int m = params.isRowMajor ? params.m : params.n; From 2a93e4e520cf618c85e89a2729f8598914b30512 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 17 Jun 2024 03:32:04 +0000 Subject: [PATCH 03/11] fix order --- cpp/include/raft/distance/detail/distance.cuh | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 08815c7ab5..a329ff2ce4 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -52,10 +52,10 @@ using distance_tag = std::integral_constant; * They are implemented below. The documentation of this function serves as * documentation for all functions. The following overloads are defined: * - * - DistanceType::DiceExpanded: * - DistanceType::Canberra: * - DistanceType::CorrelationExpanded: * - DistanceType::CosineExpanded: + * - DistanceType::DiceExpanded: * - DistanceType::HammingUnexpanded: * - DistanceType::HellingerExpanded: * - DistanceType::JensenShannon: @@ -224,8 +224,17 @@ void distance_impl(raft::resources const& handle, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. if (x == y && is_row_major) { - raft::linalg::rowNorm( - x_norm, x, k, std::max(m, n), raft::linalg::L1Norm, is_row_major, stream, {}); + raft::linalg::reduce(x_norm, + x, + k, + std::max(m, n), + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); } else { y_norm += m; raft::linalg::reduce(x_norm, From 4b64c37e7db3acd2fb1d84076b16e4a0964dd884 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 17 Jun 2024 06:58:16 +0000 Subject: [PATCH 04/11] fix comments --- cpp/include/raft/distance/detail/distance_ops/dice.cuh | 2 +- cpp/test/distance/distance_base.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance_ops/dice.cuh b/cpp/include/raft/distance/detail/distance_ops/dice.cuh index b8651c5592..edd7e42a8d 100644 --- a/cpp/include/raft/distance/detail/distance_ops/dice.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/dice.cuh @@ -36,7 +36,7 @@ struct dice_cutlass_op { * * It computes the following equation: * - * d(x, y) = 1 - 2*(x ⋅ y) / ( ||x||_1 + ||y||_1 ) + * d(x, y) = 1 - 2*(x ⋅ y) / ( Σ(x) + Σ(y) ) */ template struct dice_distance_op { diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index a52baa02fa..4629f3f4bf 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -120,7 +120,7 @@ RAFT_KERNEL naiveDiceDistanceKernel( int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; - // Use 1.0 - (dice similarity) to calc the distance + // Use 1.0 - (dice dissimilarity) to calc the distance dist[outidx] = (DataType)1.0 - (2 * acc_ab / ((acc_a) + (acc_b))); } From a1ecfcb98403cfc754de73ca9eb9927ff2b61525 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 17 Jun 2024 16:55:29 +0000 Subject: [PATCH 05/11] fix order --- cpp/include/raft/distance/detail/distance.cuh | 86 +++++++++---------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index a329ff2ce4..b708360074 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -195,7 +195,7 @@ void distance_impl(raft::resources const& handle, template void distance_impl(raft::resources const& handle, - distance_tag distance_type, + distance_tag distance_type, const DataT* x, const DataT* y, OutT* out, @@ -224,51 +224,24 @@ void distance_impl(raft::resources const& handle, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. if (x == y && is_row_major) { - raft::linalg::reduce(x_norm, - x, - k, - std::max(m, n), - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); + raft::linalg::rowNorm( + x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } else { y_norm += m; - raft::linalg::reduce(x_norm, - x, - k, - m, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - raft::linalg::reduce(y_norm, - y, - k, - n, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); + raft::linalg::rowNorm( + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } - ops::dice_distance_op distance_op{}; + ops::cosine_distance_op distance_op{}; pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template void distance_impl(raft::resources const& handle, - distance_tag distance_type, + distance_tag distance_type, const DataT* x, const DataT* y, OutT* out, @@ -297,17 +270,44 @@ void distance_impl(raft::resources const& handle, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. if (x == y && is_row_major) { - raft::linalg::rowNorm( - x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + raft::linalg::reduce(x_norm, + x, + k, + std::max(m, n), + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); } else { y_norm += m; - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - raft::linalg::rowNorm( - y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + raft::linalg::reduce(x_norm, + x, + k, + m, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + raft::linalg::reduce(y_norm, + y, + k, + n, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); } - ops::cosine_distance_op distance_op{}; + ops::dice_distance_op distance_op{}; pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } From 5cddd74355be6b8d6e91153cac187571cc936c44 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 17 Jun 2024 17:09:46 +0000 Subject: [PATCH 06/11] fix includes --- cpp/test/distance/distance_base.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 4629f3f4bf..89c6a17580 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "../test_utils.h" + #include // common::nvtx::range #include // make_device_matrix_view #include // raft::sqrt @@ -27,8 +29,6 @@ #include -#include - namespace raft { namespace distance { From 546c29b56a7bb3a890c3ab6b61ac22844c512eb3 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 17 Jun 2024 17:11:02 +0000 Subject: [PATCH 07/11] fix includes --- cpp/test/distance/distance_base.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 89c6a17580..8d8d0553a0 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.h" +#include "../test_utils.cuh" #include // common::nvtx::range #include // make_device_matrix_view From eb03c89bd06a6cd87bb8c57f862178731a63d480 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 17 Jun 2024 22:18:56 +0000 Subject: [PATCH 08/11] use binary inputs for gtest --- cpp/test/distance/dist_dice.cu | 82 +++++++++++++++++++++++++++-- cpp/test/distance/distance_base.cuh | 6 ++- 2 files changed, 83 insertions(+), 5 deletions(-) diff --git a/cpp/test/distance/dist_dice.cu b/cpp/test/distance/dist_dice.cu index 617a165e54..573c63b45b 100644 --- a/cpp/test/distance/dist_dice.cu +++ b/cpp/test/distance/dist_dice.cu @@ -24,7 +24,80 @@ template class DistanceExpDice : public DistanceTest { }; -const std::vector> inputsd = { +template +class DistanceExpDiceXequalY + : public DistanceTestSameBuffer {}; + +const std::vector> inputsf = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; + +const std::vector> inputsXeqYf = { + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, +}; + +const std::vector> inputsNaN = { + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}}; + +typedef DistanceExpDice DistanceExpDiceNaN; +TEST_P(DistanceExpDiceNaN, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_FALSE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceNaN, ::testing::ValuesIn(inputsNaN)); + +typedef DistanceExpDice DistanceExpDiceF; +TEST_P(DistanceExpDiceF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceF, ::testing::ValuesIn(inputsf)); + +typedef DistanceExpDiceXequalY DistanceExpDiceXequalYF; +TEST_P(DistanceExpDiceXequalYF, Result) +{ + int m = params.m; + int n = params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + n, + raft::CompareApprox(params.tolerance), + stream)); + n = params.isRowMajor ? m : m / 2; + m = params.isRowMajor ? m / 2 : m; + + ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m, + n, + raft::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceXequalYF, ::testing::ValuesIn(inputsXeqYf)); + +const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, {0.001, 32, 1024, 1024, true, 1234ULL}, @@ -34,15 +107,18 @@ const std::vector> inputsd = { {0.001f, 32, 1024, 1024, false, 1234ULL}, {0.003f, 1024, 1024, 1024, false, 1234ULL}, }; -typedef DistanceExpDice DistanceExpDiceD; +typedef DistanceExpDice DistanceExpDiceD; TEST_P(DistanceExpDiceD, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceD, ::testing::ValuesIn(inputsd)); +class BigMatrixDice : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixDice, Result) {} + } // end namespace distance } // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 8d8d0553a0..f44fb18519 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -513,7 +513,8 @@ class DistanceTest : public ::testing::TestWithParam> { // Hellinger works only on positive numbers uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); - } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded) { + } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded || + distanceType == raft::distance::DistanceType::DiceExpanded) { uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); // Russel rao works on boolean values. @@ -602,7 +603,8 @@ class DistanceTestSameBuffer : public ::testing::TestWithParam Date: Mon, 24 Jun 2024 16:18:35 +0000 Subject: [PATCH 09/11] fix nan problem --- cpp/test/distance/dist_dice.cu | 22 +++++----------------- cpp/test/test_utils.h | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/cpp/test/distance/dist_dice.cu b/cpp/test/distance/dist_dice.cu index 573c63b45b..e127659dc6 100644 --- a/cpp/test/distance/dist_dice.cu +++ b/cpp/test/distance/dist_dice.cu @@ -34,6 +34,7 @@ const std::vector> inputsf = { {0.001f, 1024, 32, 1024, true, 1234ULL}, {0.001f, 32, 1024, 1024, true, 1234ULL}, {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, {0.001f, 1024, 1024, 32, false, 1234ULL}, {0.001f, 1024, 32, 1024, false, 1234ULL}, {0.001f, 32, 1024, 1024, false, 1234ULL}, @@ -51,26 +52,13 @@ const std::vector> inputsXeqYf = { {0.03f, 1024, 1024, 1024, false, 1234ULL}, }; -const std::vector> inputsNaN = { - {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}}; - -typedef DistanceExpDice DistanceExpDiceNaN; -TEST_P(DistanceExpDiceNaN, Result) -{ - int m = params.isRowMajor ? params.m : params.n; - int n = params.isRowMajor ? params.n : params.m; - ASSERT_FALSE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceNaN, ::testing::ValuesIn(inputsNaN)); - typedef DistanceExpDice DistanceExpDiceF; TEST_P(DistanceExpDiceF, Result) { int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); + dist_ref.data(), dist.data(), m, n, raft::CompareApproxNaN(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceF, ::testing::ValuesIn(inputsf)); @@ -83,7 +71,7 @@ TEST_P(DistanceExpDiceXequalYF, Result) dist[0].data(), m, n, - raft::CompareApprox(params.tolerance), + raft::CompareApproxNaN(params.tolerance), stream)); n = params.isRowMajor ? m : m / 2; m = params.isRowMajor ? m / 2 : m; @@ -92,7 +80,7 @@ TEST_P(DistanceExpDiceXequalYF, Result) dist[1].data(), m, n, - raft::CompareApprox(params.tolerance), + raft::CompareApproxNaN(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceXequalYF, ::testing::ValuesIn(inputsXeqYf)); @@ -113,7 +101,7 @@ TEST_P(DistanceExpDiceD, Result) int m = params.isRowMajor ? params.m : params.n; int n = params.isRowMajor ? params.n : params.m; ASSERT_TRUE(devArrMatch( - dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); + dist_ref.data(), dist.data(), m, n, raft::CompareApproxNaN(params.tolerance), stream)); } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/test_utils.h b/cpp/test/test_utils.h index cf9a885cfb..63ec74d1da 100644 --- a/cpp/test/test_utils.h +++ b/cpp/test/test_utils.h @@ -55,6 +55,23 @@ struct CompareApprox { T eps; }; +template +struct CompareApproxNaN { + CompareApproxNaN(T eps_) : eps(eps_) {} + bool operator()(const T& a, const T& b) const + { + T diff = std::abs(a - b); + T m = std::max(std::abs(a), std::abs(b)); + T ratio = diff > eps ? diff / m : diff; + + if (std::isnan(a) && std::isnan(b)) { return true; } + return (ratio <= eps); + } + + private: + T eps; +}; + template ::std::ostream& operator<<(::std::ostream& os, const raft::KeyValuePair& kv) { From 827fcf71dca4faefa723e6e4a4e8ab6aeeb41945 Mon Sep 17 00:00:00 2001 From: aamijar Date: Fri, 28 Jun 2024 00:03:13 +0000 Subject: [PATCH 10/11] binarize input --- cpp/include/raft/distance/detail/distance.cuh | 28 ++++--------------- .../distance/detail/distance_ops/dice.cuh | 5 +++- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index b708360074..a39dbf6700 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -279,32 +279,14 @@ void distance_impl(raft::resources const& handle, true, stream, false, - raft::identity_op(), + raft::nz_op(), raft::add_op()); } else { y_norm += m; - raft::linalg::reduce(x_norm, - x, - k, - m, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - raft::linalg::reduce(y_norm, - y, - k, - n, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); + raft::linalg::reduce( + x_norm, x, k, m, (AccT)0, is_row_major, true, stream, false, raft::nz_op(), raft::add_op()); + raft::linalg::reduce( + y_norm, y, k, n, (AccT)0, is_row_major, true, stream, false, raft::nz_op(), raft::add_op()); } ops::dice_distance_op distance_op{}; diff --git a/cpp/include/raft/distance/detail/distance_ops/dice.cuh b/cpp/include/raft/distance/detail/distance_ops/dice.cuh index edd7e42a8d..362ba7eab7 100644 --- a/cpp/include/raft/distance/detail/distance_ops/dice.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/dice.cuh @@ -58,7 +58,10 @@ struct dice_distance_op { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += (x != DataT(0) ? DataT(1) : DataT(0)) * (y != DataT(0) ? DataT(1) : DataT(0)); + }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], From 0fa6b670f0f877fc3ce20ea151c7fe31ac51aa85 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 1 Jul 2024 15:02:20 +0000 Subject: [PATCH 11/11] add dice to dispatch_00_generate.py --- .../detail/pairwise_matrix/dispatch_00_generate.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py index 97fe120458..6adff0eee1 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ # NOTE: this template is not perfectly formatted. Use pre-commit to get # everything in shape again. header = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -95,6 +95,11 @@ OpT="raft::distance::detail::ops::cosine_distance_op", archs = [60, 80], ), + dict( + path_prefix="dice", + OpT="raft::distance::detail::ops::dice_distance_op", + archs = [60, 80], + ), dict( path_prefix="hamming_unexpanded", OpT="raft::distance::detail::ops::hamming_distance_op",