// Patch summary (free text placed ahead of the first `diff --git` header; three files are touched):
//
//  1. cpp/include/raft/distance/detail/distance.cuh, hunk inside distance_impl():
//     every reduction that fills the norm buffers (one call whose start lies above the hunk,
//     plus the x_norm and y_norm calls in the `else` branch) switches its element-wise
//     operator from raft::identity_op() to raft::nz_op(); the reduction operator stays
//     raft::add_op(). The two multi-line raft::linalg::reduce calls are also re-wrapped
//     onto fewer lines with otherwise unchanged arguments.
//     NOTE(review): nz_op is not defined in this patch -- presumably it maps each element
//     to a 0/1 "is non-zero" indicator; confirm against the raft operators header.
//
//  2. cpp/include/raft/distance/detail/distance_ops/dice.cuh, dice_distance_op::core():
//     the accumulation changes from `acc += x * y` to adding the product of two 0/1
//     indicators, (x != 0) and (y != 0), so only positions where both inputs are
//     non-zero contribute to acc.
//
//  3. cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py:
//     copyright years are bumped to 2024 (script header and the generated-file header
//     template), and a dict with path_prefix="dice",
//     OpT="raft::distance::detail::ops::dice_distance_op", archs = [60, 80] is inserted
//     between the cosine_distance_op entry and the hamming_unexpanded entry.
//
// NOTE(review): the hunk headers announce multi-line hunks (e.g. -279,32 +279,14), but in
// this copy the line breaks are collapsed, and some context (e.g. `template DI void epilog(`)
// looks like it has lost its `<...>` template parameter list. Regenerate the patch from
// version control before trying to apply it.
diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index b708360074..a39dbf6700 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -279,32 +279,14 @@ void distance_impl(raft::resources const& handle, true, stream, false, - raft::identity_op(), + raft::nz_op(), raft::add_op()); } else { y_norm += m; - raft::linalg::reduce(x_norm, - x, - k, - m, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - raft::linalg::reduce(y_norm, - y, - k, - n, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); + raft::linalg::reduce( + x_norm, x, k, m, (AccT)0, is_row_major, true, stream, false, raft::nz_op(), raft::add_op()); + raft::linalg::reduce( + y_norm, y, k, n, (AccT)0, is_row_major, true, stream, false, raft::nz_op(), raft::add_op()); } ops::dice_distance_op distance_op{}; diff --git a/cpp/include/raft/distance/detail/distance_ops/dice.cuh b/cpp/include/raft/distance/detail/distance_ops/dice.cuh index edd7e42a8d..362ba7eab7 100644 --- a/cpp/include/raft/distance/detail/distance_ops/dice.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/dice.cuh @@ -58,7 +58,10 @@ struct dice_distance_op { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += (x != DataT(0) ? DataT(1) : DataT(0)) * (y != DataT(0) ? 
DataT(1) : DataT(0)); + }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py index 97fe120458..6adff0eee1 100644 --- a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ # NOTE: this template is not perfectly formatted. Use pre-commit to get # everything in shape again. header = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -95,6 +95,11 @@ OpT="raft::distance::detail::ops::cosine_distance_op", archs = [60, 80], ), + dict( + path_prefix="dice", + OpT="raft::distance::detail::ops::dice_distance_op", + archs = [60, 80], + ), dict( path_prefix="hamming_unexpanded", OpT="raft::distance::detail::ops::hamming_distance_op",