diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index b4de76037a..bc9e09e5b0 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -90,10 +90,14 @@ inline void knn_merge_parts( RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), "in_keys and in_values must have the same shape."); RAFT_EXPECTS( - out_keys.extent(0) == out_values.extent(0) == n_samples, + out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == n_samples, "Number of rows in output keys and val matrices must equal number of rows in search matrix."); - RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1), - "Number of columns in output indices and distances matrices must be equal to k"); + RAFT_EXPECTS( + out_keys.extent(1) == out_values.extent(1) && out_keys.extent(1) == in_keys.extent(1), + "Number of columns in output indices and distances matrices must be equal to k"); + + idx_t* translations_ptr = nullptr; + if (translations.has_value()) { translations_ptr = translations.value().data_handle(); } auto n_parts = in_keys.extent(0) / n_samples; detail::knn_merge_parts(in_keys.data_handle(), @@ -104,7 +108,7 @@ inline void knn_merge_parts( n_parts, in_keys.extent(1), resource::get_cuda_stream(handle), - translations.value_or(nullptr)); + translations_ptr); } /** diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh index e2b5c41fb0..0a33832b79 100644 --- a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh +++ b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh @@ -30,8 +30,8 @@ template -__global__ void knn_merge_parts_kernel(value_t* inK, - value_idx* inV, +__global__ void knn_merge_parts_kernel(const value_t* inK, + const value_idx* inV, value_t* outK, value_idx* outV, size_t n_samples, @@ -65,8 +65,8 @@ __global__ void knn_merge_parts_kernel(value_t* inK, int col = i % k; - value_t* inKStart = inK + (row_idx + col); - value_idx* inVStart = inV + (row_idx + col); + const value_t* inKStart = inK + (row_idx + col); + const value_idx* inVStart = inV + (row_idx + col); int limit = Pow2::roundDown(total_k); value_idx translation = 0; @@ -99,8 +99,8 @@ __global__ void knn_merge_parts_kernel(value_t* inK, } template -inline void knn_merge_parts_impl(value_t* inK, - value_idx* inV, +inline void knn_merge_parts_impl(const value_t* inK, + const value_idx* inV, value_t* outK, value_idx* outV, size_t n_samples, @@ -137,8 +137,8 @@ inline void knn_merge_parts_impl(value_t* inK, * @param translations mapping of index offsets for each partition */ template -inline void knn_merge_parts(value_t* inK, - value_idx* inV, +inline void knn_merge_parts(const value_t* inK, + const value_idx* inV, value_t* outK, value_idx* outV, size_t n_samples, diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index 7b088316a3..3c089b1d22 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -50,8 +50,8 @@ namespace raft::spatial::knn { * @param translations */ template -inline void knn_merge_parts(value_t* in_keys, - idx_t* in_values, +inline void knn_merge_parts(const value_t* in_keys, + const idx_t* in_values, value_t* out_keys, idx_t* out_values, size_t n_samples,