Skip to content

Commit

Permalink
[FEA] support of prefiltered brute force (#146)
Browse files Browse the repository at this point in the history
- The PR is one part of prefiltered brute force and should work with the PR of raft: rapidsai/raft#2294

Authors:
  - rhdong (https://github.com/rhdong)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #146
  • Loading branch information
rhdong authored May 29, 2024
1 parent 1912355 commit c533fe3
Show file tree
Hide file tree
Showing 11 changed files with 986 additions and 24 deletions.
27 changes: 27 additions & 0 deletions cpp/include/cuvs/core/bitmap.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/bitmap.hpp>

namespace cuvs::core {
/* To use bitmap functions containing CUDA code, include <raft/core/bitmap.cuh> */

/**
 * @brief cuVS-side name for raft's bitmap view.
 *
 * This is a pure alias: `cuvs::core::bitmap_view<bitmap_t, index_t>` is exactly
 * `raft::core::bitmap_view<bitmap_t, index_t>`, so the two can be used interchangeably.
 *
 * @tparam bitmap_t first template argument forwarded to raft::core::bitmap_view
 *                  (presumably the word type of the underlying bit storage - confirm against raft)
 * @tparam index_t  second template argument forwarded to raft::core::bitmap_view
 *                  (presumably the type used for sizes and indices - confirm against raft)
 */
template <typename bitmap_t, typename index_t>
using bitmap_view = raft::core::bitmap_view<bitmap_t, index_t>;

} // end namespace cuvs::core
5 changes: 4 additions & 1 deletion cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,15 @@ auto build(raft::resources const& handle,
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter);
/**
* @}
*/
Expand Down
1 change: 1 addition & 0 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/util/cudart_utils.hpp> // get_device_for_address
#include <raft/util/integer_utils.hpp> // rounding up

#include <cuvs/core/bitmap.hpp>
#include <cuvs/core/bitset.hpp>
#include <raft/core/detail/macros.hpp>

Expand Down
46 changes: 27 additions & 19 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "./detail/knn_brute_force.cuh"

#include <cuvs/neighbors/brute_force.hpp>

#include <raft/core/copy.hpp>
Expand Down Expand Up @@ -84,25 +85,32 @@ void index<T>::update_dataset(raft::resources const& res,
dataset_view_ = raft::make_const_mdspan(dataset_.view());
}

#define CUVS_INST_BFKNN(T) \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::distance::DistanceType metric, \
T metric_arg) \
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
\
void search(raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T>& idx, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<T, int64_t, raft::row_major> distances) \
{ \
detail::brute_force_search<T, int64_t>(res, idx, queries, neighbors, distances); \
} \
\
/*
 * CUVS_INST_BFKNN(T) emits, for element type T:
 *   - a non-template `build` overload that forwards to detail::build<T>;
 *   - a non-template `search` overload that dispatches on `sample_filter`:
 *       * no filter supplied -> detail::brute_force_search<T, int64_t>
 *       * filter supplied    -> detail::brute_force_search_filtered<T, int64_t>, passing the
 *                               unwrapped bitmap view;
 *   - an explicit instantiation of cuvs::neighbors::brute_force::index<T>.
 *
 * NOTE(review): `sample_filter` carries a `= std::nullopt` default on this definition only, so the
 * default is visible solely to code in this translation unit that follows the macro expansion.
 */
#define CUVS_INST_BFKNN(T) \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::distance::DistanceType metric, \
T metric_arg) \
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
\
void search( \
raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T>& idx, \
raft::device_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<T, int64_t, raft::row_major> distances, \
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter = std::nullopt) \
{ \
if (!sample_filter.has_value()) { \
detail::brute_force_search<T, int64_t>(res, idx, queries, neighbors, distances); \
} else { \
detail::brute_force_search_filtered<T, int64_t>( \
res, idx, queries, *sample_filter, neighbors, distances); \
} \
} \
\
template struct cuvs::neighbors::brute_force::index<T>;

CUVS_INST_BFKNN(float);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/brute_force_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void _search(cuvsResources_t res,
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);

cuvs::neighbors::brute_force::search(
*res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds);
*res_ptr, *index_ptr, queries_mds, neighbors_mds, distances_mds, std::nullopt);
}

} // namespace
Expand Down
202 changes: 201 additions & 1 deletion cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/brute_force.hpp>

Expand All @@ -23,16 +24,26 @@
#include "./fused_l2_knn.cuh"
#include "./haversine_distance.cuh"
#include "./knn_merge_parts.cuh"
#include "./knn_utils.cuh"

#include <raft/core/bitmap.cuh>
#include <raft/core/detail/popc.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cuda_stream_pool.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/init.cuh>
#include <raft/matrix/select_k.cuh>
#include <raft/sparse/convert/coo.cuh>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/distance/detail/utils.cuh>
#include <raft/sparse/linalg/sddmm.hpp>
#include <raft/sparse/matrix/select_k.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -65,7 +76,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
size_t max_row_tile_size = 0,
size_t max_col_tile_size = 0,
const ElementType* precomputed_index_norms = nullptr,
const ElementType* precomputed_search_norms = nullptr)
const ElementType* precomputed_search_norms = nullptr,
const uint32_t* filter_bitmap = nullptr)
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
Expand Down Expand Up @@ -214,6 +226,27 @@ void tiled_brute_force_knn(const raft::resources& handle,
});
}

if (filter_bitmap != nullptr) {
auto distances_ptr = temp_distances.data();
auto count = thrust::make_counting_iterator<IndexType>(0);
ElementType masked_distance = select_min ? std::numeric_limits<ElementType>::infinity()
: std::numeric_limits<ElementType>::lowest();
thrust::for_each(raft::resource::get_thrust_policy(handle),
count,
count + current_query_size * current_centroid_size,
[=] __device__(IndexType idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
IndexType g_idx = row * n + col;
IndexType item_idx = (g_idx) >> 5;
uint32_t bit_idx = (g_idx)&31;
uint32_t filter = filter_bitmap[item_idx];
if ((filter & (uint32_t(1) << bit_idx)) == 0) {
distances_ptr[idx] = masked_distance;
}
});
}

raft::matrix::select_k<ElementType, IndexType>(
handle,
raft::make_device_matrix_view<const ElementType, int64_t, raft::row_major>(
Expand Down Expand Up @@ -519,6 +552,173 @@ void brute_force_search(
query_norms ? query_norms->data_handle() : nullptr);
}

/**
 * @brief Brute-force k-NN search in which a bitmap decides, per (query, sample) pair, whether the
 *        sample may be returned for that query.
 *
 * The filter is consumed as n_queries * n_dataset bits; a set bit keeps the pair, a cleared bit
 * excludes it. The number of set bits is counted first (this synchronizes the stream), and the
 * fraction of kept pairs selects one of two strategies:
 *
 *  - more than 1% of the pairs kept: run the dense tiled brute-force search and hand it the raw
 *    bitmap so excluded pairs are masked out before top-k selection;
 *  - otherwise: convert the bitmap to a CSR structure, compute dot products only at the kept
 *    positions (SDDMM when there are more than 10 queries, a dedicated CSR dot kernel
 *    otherwise), turn them into L2 / cosine distances when the metric requires it, and run a
 *    sparse top-k selection.
 *
 * @tparam T       element type of the dataset, queries and distances
 * @tparam IdxT    index type of the views and of the returned neighbor ids
 * @tparam BitmapT word type of the filter storage
 *
 * @param[in]  res         raft resources (stream, allocators)
 * @param[in]  idx         brute-force index; must carry norms for L2Expanded, L2SqrtExpanded and
 *                         CosineExpanded
 * @param[in]  queries     query vectors [n_queries, dim]
 * @param[in]  filter      bitmap filter covering n_queries * n_dataset bits
 * @param[out] neighbors   ids of the selected neighbors [n_queries, k]
 * @param[out] distances   distances to the selected neighbors [n_queries, k]
 * @param[in]  query_norms optional precomputed query norms; computed internally when absent and
 *                         needed by the metric
 */
template <typename T, typename IdxT, typename BitmapT>
void brute_force_search_filtered(
  raft::resources const& res,
  const cuvs::neighbors::brute_force::index<T>& idx,
  raft::device_matrix_view<const T, IdxT, raft::row_major> queries,
  cuvs::core::bitmap_view<const BitmapT, IdxT> filter,
  raft::device_matrix_view<IdxT, IdxT, raft::row_major> neighbors,
  raft::device_matrix_view<T, IdxT, raft::row_major> distances,
  std::optional<raft::device_vector_view<const T, IdxT>> query_norms = std::nullopt)
{
  // Fraction of kept (query, sample) pairs above which the dense tiled search is used.
  constexpr float kDenseFilterThreshold = 0.01f;
  // Query count above which SDDMM is used instead of the dedicated CSR dot kernel.
  constexpr IdxT kSddmmMinQueries = 10;

  const auto metric = idx.metric();
  const bool is_cosine = metric == cuvs::distance::DistanceType::CosineExpanded;
  // Metrics whose expanded form needs the row norms of both operands.
  const bool uses_norms = is_cosine || metric == cuvs::distance::DistanceType::L2Expanded ||
                          metric == cuvs::distance::DistanceType::L2SqrtExpanded;

  RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs");
  RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1),
               "Number of columns in queries must match brute force index");
  RAFT_EXPECTS(metric == cuvs::distance::DistanceType::InnerProduct ||
                 metric == cuvs::distance::DistanceType::L2Expanded ||
                 metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
                 metric == cuvs::distance::DistanceType::CosineExpanded,
               "Only Euclidean, IP, and Cosine are supported!");
  // Inner product does not need norms, so it is exempt from this requirement.
  RAFT_EXPECTS(idx.has_norms() || !uses_norms,
               "Index must have norms when using Euclidean or Cosine metrics!");

  const IdxT n_queries = queries.extent(0);
  const IdxT n_dataset = idx.dataset().extent(0);
  const IdxT dim       = idx.dataset().extent(1);
  const IdxT k         = neighbors.extent(1);

  auto stream = raft::resource::get_cuda_stream(res);

  // Count the set bits of the filter: this is the number of kept (query, sample) pairs.
  IdxT nnz_h = 0;
  rmm::device_scalar<IdxT> nnz(0, stream);
  auto nnz_view = raft::make_device_scalar_view<IdxT>(nnz.data());
  auto filter_view =
    raft::make_device_vector_view<const BitmapT, IdxT>(filter.data(), filter.n_elements());

  // TODO(rhdong): Need to switch to the public API,
  // with the issue: https://github.com/rapidsai/cuvs/issues/158
  raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view);
  raft::copy(&nnz_h, nnz.data(), 1, stream);

  // The count is needed on the host to choose a strategy and to size the CSR matrix.
  raft::resource::sync_stream(res, stream);
  const float density = (1.0f * nnz_h / (1.0f * n_queries * n_dataset));

  if (density > kDenseFilterThreshold) {
    // Dense path: tiled brute force over all pairs, masking excluded ones with the bitmap.
    raft::resources stream_pool_handle(res);
    raft::resource::set_cuda_stream(stream_pool_handle, stream);
    const T* idx_norm = idx.has_norms() ? idx.norms().data_handle() : nullptr;
    // Forward caller-provided query norms, as the unfiltered brute_force_search does.
    const T* precomputed_query_norms = query_norms ? query_norms->data_handle() : nullptr;

    tiled_brute_force_knn<T, IdxT>(stream_pool_handle,
                                   queries.data_handle(),
                                   idx.dataset().data_handle(),
                                   n_queries,
                                   n_dataset,
                                   dim,
                                   k,
                                   distances.data_handle(),
                                   neighbors.data_handle(),
                                   metric,
                                   2.0,
                                   0,
                                   0,
                                   idx_norm,
                                   precomputed_query_norms,
                                   filter.data());
  } else {
    // Sparse path: only the kept pairs are materialized, as the non-zeros of a CSR matrix.
    auto csr = raft::make_device_csr_matrix<T, IdxT>(res, n_queries, n_dataset, nnz_h);

    // fill csr
    raft::sparse::convert::bitmap_to_csr(res, filter, csr);

    // Expand the CSR row pointers into one row id per non-zero (used by the epilogue).
    auto compressed_csr_view = csr.structure_view();
    rmm::device_uvector<IdxT> rows(compressed_csr_view.get_nnz(), stream);
    raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(),
                                      compressed_csr_view.get_n_rows(),
                                      rows.data(),
                                      compressed_csr_view.get_nnz(),
                                      stream);

    // Dot products between each query and its kept samples, written into the CSR values.
    if (n_queries > kSddmmMinQueries) {
      auto csr_view = raft::make_device_csr_matrix_view<T, IdxT, IdxT, IdxT>(
        csr.get_elements().data(), compressed_csr_view);

      // The row-major [n_dataset, dim] dataset read as col-major [dim, n_dataset] is its
      // transpose, which is the right-hand operand SDDMM needs.
      auto dataset_view = raft::make_device_matrix_view<const T, IdxT, raft::col_major>(
        idx.dataset().data_handle(), dim, n_dataset);

      T alpha = static_cast<T>(1.0f);
      T beta  = static_cast<T>(0.0f);
      raft::sparse::linalg::sddmm(res,
                                  queries,
                                  dataset_view,
                                  csr_view,
                                  raft::linalg::Operation::NON_TRANSPOSE,
                                  raft::linalg::Operation::NON_TRANSPOSE,
                                  raft::make_host_scalar_view<T>(&alpha),
                                  raft::make_host_scalar_view<T>(&beta));
    } else {
      raft::sparse::distance::detail::faster_dot_on_csr(res,
                                                        csr.get_elements().data(),
                                                        compressed_csr_view.get_nnz(),
                                                        compressed_csr_view.get_indptr().data(),
                                                        compressed_csr_view.get_indices().data(),
                                                        queries.data_handle(),
                                                        idx.dataset().data_handle(),
                                                        compressed_csr_view.get_n_rows(),
                                                        dim);
    }

    // Post-process: convert dot products into the requested distance for norm-based metrics.
    std::optional<raft::device_vector<T, IdxT>> computed_query_norms;
    if (uses_norms) {
      if (!query_norms) {
        computed_query_norms = raft::make_device_vector<T, IdxT>(res, n_queries);
        if (is_cosine) {
          // Cosine uses the square root of the summed squares.
          raft::linalg::rowNorm(computed_query_norms->data_handle(),
                                queries.data_handle(),
                                dim,
                                n_queries,
                                raft::linalg::L2Norm,
                                true,
                                stream,
                                raft::sqrt_op{});
        } else {
          // L2 variants use the summed squares unchanged.
          raft::linalg::rowNorm(computed_query_norms->data_handle(),
                                queries.data_handle(),
                                dim,
                                n_queries,
                                raft::linalg::L2Norm,
                                true,
                                stream,
                                raft::identity_op{});
        }
      }
      cuvs::neighbors::detail::epilogue_on_csr(
        res,
        csr.get_elements().data(),
        compressed_csr_view.get_nnz(),
        rows.data(),
        compressed_csr_view.get_indices().data(),
        query_norms ? query_norms->data_handle() : computed_query_norms->data_handle(),
        idx.norms().data_handle(),
        metric);
    }

    // select k
    auto const_csr_view = raft::make_device_csr_matrix_view<const T, IdxT, IdxT, IdxT>(
      csr.get_elements().data(), compressed_csr_view);
    std::optional<raft::device_vector_view<const IdxT, IdxT>> no_opt = std::nullopt;
    const bool select_min = cuvs::distance::is_min_close(metric);
    raft::sparse::matrix::select_k(
      res, const_csr_view, no_opt, distances, neighbors, select_min, true);
  }
}

template <typename T>
cuvs::neighbors::brute_force::index<T> build(
raft::resources const& res,
Expand Down
Loading

0 comments on commit c533fe3

Please sign in to comment.