Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CAGRA-Q subspace dim = 4 support #2244

Merged
merged 8 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ void launch_vpq_search_main_core(
CagraSampleFilterT sample_filter)
{
RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2, "Only pq_len 2 is supported for now");
RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4,
"Only pq_len 2 or 4 is supported for now");
RAFT_EXPECTS(vpq_dset->dim() % vpq_dset->pq_dim() == 0,
"dim must be a multiple of pq_dim at the moment");

Expand Down
29 changes: 16 additions & 13 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
using CODE_BOOK_T = CODE_BOOK_T_;
using QUERY_T = typename dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T>::QUERY_T;

static_assert(std::is_same_v<CODE_BOOK_T, half>, "Only CODE_BOOK_T = `half` is supported now");

const std::uint8_t* encoded_dataset_ptr;
const std::uint32_t encoded_dataset_dim;
const std::uint32_t n_subspace;
Expand All @@ -53,18 +55,19 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
smem_pq_code_book_ptr = reinterpret_cast<CODE_BOOK_T*>(smem_ptr);

// Copy PQ table
if constexpr (std::is_same<CODE_BOOK_T, half>::value) {
for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) {
half2 buf2;
buf2.x = pq_code_book_ptr[i];
buf2.y = pq_code_book_ptr[i + 1];
(reinterpret_cast<half2*>(smem_pq_code_book_ptr + i))[0] = buf2;
}
} else {
for (unsigned i = threadIdx.x; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x) {
// TODO: vectorize
smem_pq_code_book_ptr[i] = pq_code_book_ptr[i];
}
for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) {
half2 buf2;
buf2.x = pq_code_book_ptr[i];
buf2.y = pq_code_book_ptr[i + 1];

// Change the order of PQ code book array to reduce the
// frequency of bank conflicts.
constexpr auto num_elements_per_bank = 4 / utils::size_of<CODE_BOOK_T>();
constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank;
const auto j = i / num_elements_per_bank;
const auto smem_index =
(j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS);
reinterpret_cast<half2*>(smem_pq_code_book_ptr)[smem_index] = buf2;
}
}

Expand Down Expand Up @@ -136,7 +139,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
4 + k));
}
//
if constexpr ((std::is_same<CODE_BOOK_T, half>::value) && (PQ_LEN % 2 == 0)) {
if constexpr (PQ_LEN % 2 == 0) {
// **** Use half2 for distance computation ****
half2 norm2{0, 0};
#pragma unroll
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& da
vpq_params r = params;
double n_rows = dataset.extent(0);
size_t dim = dataset.extent(1);
if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{2}); }
if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); }
if (r.pq_bits == 0) { r.pq_bits = 8; }
if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe<uint32_t>(std::sqrt(n_rows), 8); }
if (r.vq_kmeans_trainset_fraction == 0) {
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/neighbors/ann_cagra_vpq.cuh
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class AnnCagraVpqTest : public ::testing::TestWithParam<AnnCagraVpqInputs> {
resource::sync_stream(handle_);
}

const auto vpq_k = ps.k * 16;
const auto vpq_k = ps.k * 4;
{
rmm::device_uvector<DistanceT> distances_dev(vpq_k * ps.n_queries, stream_);
rmm::device_uvector<IdxT> indices_dev(vpq_k * ps.n_queries, stream_);
Expand Down Expand Up @@ -319,7 +319,7 @@ const std::vector<AnnCagraVpqInputs> vpq_inputs = raft::util::itertools::product
{1000, 10000}, // n_rows
{128, 132, 192, 256, 512, 768}, // dim
{8, 12}, // k
{2}, // pq_len
{2, 4}, // pq_len
{8}, // pq_bits
{graph_build_algo::NN_DESCENT}, // build_algo
{search_algo::SINGLE_CTA, search_algo::MULTI_CTA}, // algo
Expand Down
Loading