From 706eb3928f03063f791119e877c8f726f8283a8c Mon Sep 17 00:00:00 2001 From: Jinsol Park Date: Mon, 22 Jul 2024 07:19:13 -0700 Subject: [PATCH] Use slicing kernel to copy distances inside NN Descent (#2380) This makes use of raft's slicing kernel within NN Descent build. I found that my previous implementation was inefficient (merged in [this PR](https://github.com/rapidsai/raft/pull/2345)). ### Improvements Time to call NN Descent `build()` with `return_distances=True` before and after using this kernel | Dataset | Before |After| | ------------- | ------------- |---| | mnist (60000, 784) | 1550ms | 1020ms| | sift (1M, 128) | 11342ms |5546ms| | gist (1M, 960) | 13508ms |9278ms| Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Micka (https://github.com/lowener) URL: https://github.com/rapidsai/raft/pull/2380 --- .../raft/neighbors/detail/nn_descent.cuh | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index ad16f3e11d..9c37ee146d 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -19,11 +19,14 @@ #include "../nn_descent_types.hpp" #include +#include #include #include +#include #include #include #include +#include #include #include #include // raft::util::arch::SM_* @@ -1365,12 +1368,22 @@ void GNND::build(Data_t* data, static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); if (return_distances) { - for (size_t i = 0; i < (size_t)nrow_; i++) { - raft::copy(output_distances + i * build_config_.output_graph_degree, - graph_.h_dists.data_handle() + i * build_config_.node_degree, - build_config_.output_graph_degree, - raft::resource::get_cuda_stream(res)); - } + auto graph_d_dists = raft::make_device_matrix( + res, nrow_, build_config_.node_degree); + 
raft::copy(graph_d_dists.data_handle(), + graph_.h_dists.data_handle(), + nrow_ * build_config_.node_degree, + raft::resource::get_cuda_stream(res)); + + auto output_dist_view = raft::make_device_matrix_view( + output_distances, nrow_, build_config_.output_graph_degree); + + raft::matrix::slice_coordinates coords{static_cast(0), + static_cast(0), + static_cast(nrow_), + static_cast(build_config_.output_graph_degree)}; + raft::matrix::slice( + res, raft::make_const_mdspan(graph_d_dists.view()), output_dist_view, coords); } Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle();