Skip to content

Commit

Permalink
make helper static
Browse files Browse the repository at this point in the history
  • Loading branch information
tarang-jain committed Apr 24, 2024
1 parent 54061d0 commit dda1810
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 14 deletions.
3 changes: 1 addition & 2 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ namespace raft::neighbors::cagra {
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* ivf_pq::index_params build_params;
* build_params.initialize_from_dataset(dataset);
* ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset);
* ivf_pq::search_params search_params;
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
Expand Down
5 changes: 1 addition & 4 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ void build_knn_graph(raft::resources const& res,
size_t(dataset.extent(1)),
node_degree);

if (!build_params) {
build_params = ivf_pq::index_params{};
build_params.value().initialize_from_dataset(dataset);
}
if (!build_params) { build_params = ivf_pq::index_params::from_dataset(dataset); }

// Make model name
const std::string model_name = [&]() {
Expand Down
15 changes: 9 additions & 6 deletions cpp/include/raft/neighbors/ivf_pq_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,19 @@ struct index_params : ann::index_params {
* Helper that sets values according to the extents of the dataset mdspan.
*/
template <typename DataT, typename Accessor>
void initialize_from_dataset(
static index_params from_dataset(
mdspan<const DataT, matrix_extent<int64_t>, row_major, Accessor> dataset,
raft::distance::DistanceType metric = raft::distance::L2Expanded)
{
n_lists =
index_params params;
params.n_lists =
dataset.extent(0) < 4 * 2500 ? 4 : static_cast<uint32_t>(std::sqrt(dataset.extent(0)));
pq_dim = round_up_safe(static_cast<uint32_t>(dataset.extent(1) / 4), static_cast<uint32_t>(8));
pq_bits = 8;
kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 0.1;
this->metric = metric;
params.pq_dim =
round_up_safe(static_cast<uint32_t>(dataset.extent(1) / 4), static_cast<uint32_t>(8));
params.pq_bits = 8;
params.kmeans_trainset_fraction = dataset.extent(0) < 10000 ? 1 : 0.1;
params.metric = metric;
return params;
}
};

Expand Down
3 changes: 1 addition & 2 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {
raft::make_host_matrix<IdxT, int64_t>(ps.n_rows, index_params.intermediate_graph_degree);

if (ps.build_algo == graph_build_algo::IVF_PQ) {
auto build_params = ivf_pq::index_params{};
build_params.initialize_from_dataset(database_view, ps.metric);
auto build_params = ivf_pq::index_params::from_dataset(database_view, ps.metric);
if (ps.host_dataset) {
cagra::build_knn_graph<DataT, IdxT>(
handle_, database_host_view, knn_graph.view(), 2, build_params);
Expand Down

0 comments on commit dda1810

Please sign in to comment.