From dda18100cde9b4e78a14e64ee9afb8f1e74b8140 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 23 Apr 2024 17:15:01 -0700 Subject: [PATCH] make helper static --- cpp/include/raft/neighbors/cagra.cuh | 3 +-- .../raft/neighbors/detail/cagra/cagra_build.cuh | 5 +---- cpp/include/raft/neighbors/ivf_pq_types.hpp | 15 +++++++++------ cpp/test/neighbors/ann_cagra.cuh | 3 +-- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 8fbd8e6ee6..c4406b66c5 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -55,8 +55,7 @@ namespace raft::neighbors::cagra { * @code{.cpp} * using namespace raft::neighbors; * // use default index parameters - * ivf_pq::index_params build_params; - * build_params.initialize_from_dataset(dataset); + * ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset); * ivf_pq::search_params search_params; * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // create knn graph diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index fad720a03b..f6bb0bf8fe 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -59,10 +59,7 @@ void build_knn_graph(raft::resources const& res, size_t(dataset.extent(1)), node_degree); - if (!build_params) { - build_params = ivf_pq::index_params{}; - build_params.value().initialize_from_dataset(dataset); - } + if (!build_params) { build_params = ivf_pq::index_params::from_dataset(dataset); } // Make model name const std::string model_name = [&]() { diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 5a3e6caa24..b590776222 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -109,16 +109,19 @@ struct index_params : ann::index_params { * Helper that sets values according to the extents of the dataset mdspan. */ template - void initialize_from_dataset( + static index_params from_dataset( mdspan, row_major, Accessor> dataset, raft::distance::DistanceType metric = raft::distance::L2Expanded) { - n_lists = + index_params params; + params.n_lists = dataset.extent(0) < 4 * 2500 ? 4 : static_cast(std::sqrt(dataset.extent(0))); - pq_dim = round_up_safe(static_cast(dataset.extent(1) / 4), static_cast(8)); - pq_bits = 8; - kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 0.1; - this->metric = metric; + params.pq_dim = + round_up_safe(static_cast(dataset.extent(1) / 4), static_cast(8)); + params.pq_bits = 8; + params.kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 0.1; + params.metric = metric; + return params; } }; diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 499abd7e26..715a94403f 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -417,8 +417,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam { raft::make_host_matrix(ps.n_rows, index_params.intermediate_graph_degree); if (ps.build_algo == graph_build_algo::IVF_PQ) { - auto build_params = ivf_pq::index_params{}; - build_params.initialize_from_dataset(database_view, ps.metric); + auto build_params = ivf_pq::index_params::from_dataset(database_view, ps.metric); if (ps.host_dataset) { cagra::build_knn_graph( handle_, database_host_view, knn_graph.view(), 2, build_params);