Skip to content

Commit

Permalink
feat: do flat search if too many rows are filtered out (#2583)
Browse files Browse the repository at this point in the history
- replace RoaringBitmap with BitVec
- use pool to avoid allocating bitset for each query with prefilter
- fall back to flat search if too many rows filtered out
- prefetch with flat search

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
  • Loading branch information
BubbleCal authored and eddyxu committed Jul 11, 2024
1 parent 85bab4d commit a8603a1
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 25 deletions.
6 changes: 5 additions & 1 deletion rust/lance-index/src/vector/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ impl<'a> Visited<'a> {
let node_id_usize = node_id as usize;
self.visited[node_id_usize]
}

pub fn count_ones(&self) -> usize {
self.visited.count_ones()
}
}

impl<'a> Drop for Visited<'a> {
Expand Down Expand Up @@ -230,7 +234,7 @@ pub fn beam_search(
ep: &OrderedNode,
k: usize,
dist_calc: &impl DistCalculator,
bitset: Option<&roaring::bitmap::RoaringBitmap>,
bitset: Option<&Visited>,
prefetch_distance: Option<usize>,
visited: &mut Visited,
) -> Vec<OrderedNode> {
Expand Down
87 changes: 63 additions & 24 deletions rust/lance-index/src/vector/hnsw/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ use itertools::Itertools;

use lance_linalg::distance::DistanceType;
use rayon::prelude::*;
use roaring::RoaringBitmap;
use snafu::{location, Location};
use std::cmp::min;
use std::collections::HashMap;
use std::collections::{BinaryHeap, HashMap};
use std::fmt::Debug;
use std::iter;
use std::sync::atomic::{AtomicUsize, Ordering};
Expand All @@ -32,7 +31,7 @@ use super::{select_neighbors_heuristic, HnswMetadata, HNSW_TYPE, VECTOR_ID_COL,
use crate::prefilter::PreFilter;
use crate::vector::flat::storage::FlatStorage;
use crate::vector::graph::builder::GraphBuilderNode;
use crate::vector::graph::greedy_search;
use crate::vector::graph::{greedy_search, Visited};
use crate::vector::graph::{
Graph, OrderedFloat, OrderedNode, VisitedGenerator, DISTS_FIELD, NEIGHBORS_COL, NEIGHBORS_FIELD,
};
Expand Down Expand Up @@ -164,7 +163,7 @@ impl HNSW {
query: ArrayRef,
k: usize,
ef: usize,
bitset: Option<RoaringBitmap>,
bitset: Option<Visited>,
visited_generator: &mut VisitedGenerator,
storage: &impl VectorStore,
prefetch_distance: Option<usize>,
Expand Down Expand Up @@ -198,7 +197,7 @@ impl HNSW {
query: ArrayRef,
k: usize,
ef: usize,
bitset: Option<RoaringBitmap>,
bitset: Option<Visited>,
storage: &impl VectorStore,
) -> Result<Vec<OrderedNode>> {
let mut visited_generator = self
Expand Down Expand Up @@ -572,7 +571,7 @@ impl IvfSubIndex for HNSW {
}

let visited_generator_queue = Arc::new(ArrayQueue::new(num_cpus::get() * 2));
for _ in 0..(num_cpus::get() * 2) {
for _ in 0..num_cpus::get() * 2 {
visited_generator_queue
.push(VisitedGenerator::new(0))
.unwrap();
Expand Down Expand Up @@ -620,34 +619,74 @@ impl IvfSubIndex for HNSW {
storage: &impl VectorStore,
prefilter: Arc<dyn PreFilter>,
) -> Result<RecordBatch> {
let schema = VECTOR_RESULT_SCHEMA.clone();
if params.ef < k {
return Err(Error::Index {
message: "ef must be greater than or equal to k".to_string(),
location: location!(),
});
}

let schema = VECTOR_RESULT_SCHEMA.clone();
if self.is_empty() {
return Ok(RecordBatch::new_empty(schema));
}

let bitmap = if prefilter.is_empty() {
let mut prefilter_generator = self
.inner
.visited_generator_queue
.pop()
.unwrap_or_else(|| VisitedGenerator::new(storage.len()));
let prefilter_bitset = if prefilter.is_empty() {
None
} else {
let indices = prefilter.filter_row_ids(Box::new(storage.row_ids()));
Some(
RoaringBitmap::from_sorted_iter(indices.into_iter().map(|i| i as u32)).map_err(
|e| Error::Index {
message: format!("Error creating RoaringBitmap: {}", e),
location: location!(),
},
)?,
)
let mut bitset = prefilter_generator.generate(storage.len());
for indices in indices {
bitset.insert(indices as u32);
}
Some(bitset)
};

if params.ef < k {
return Err(Error::Index {
message: "ef must be greater than or equal to k".to_string(),
location: location!(),
});
}

let results = self.search_basic(query.clone(), k, params.ef, bitmap, storage)?;
let remained = prefilter_bitset
.as_ref()
.map(|b| b.count_ones())
.unwrap_or(storage.len());
let results = if remained < self.len() * 10 / 100 {
log::debug!("too many rows filtered, using flat search");
let prefilter_bitset =
prefilter_bitset.expect("the prefilter bitset must be set for flat search");
let node_ids = storage
.row_ids()
.enumerate()
.filter_map(|(node_id, _)| {
prefilter_bitset
.contains(node_id as u32)
.then_some(node_id as u32)
})
.collect_vec();
let dist_calc = storage.dist_calculator(query);
let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
for i in 0..node_ids.len() {
if let Some(ahead) = self.inner.params.prefetch_distance {
if i + ahead < node_ids.len() {
dist_calc.prefetch(node_ids[i + ahead]);
}
}
let node_id = node_ids[i];
let dist = dist_calc.distance(node_id).into();
if heap.len() < k {
heap.push((dist, node_id).into());
} else if dist < heap.peek().unwrap().dist {
heap.pop();
heap.push((dist, node_id).into());
}
}
heap.into_sorted_vec()
} else {
self.search_basic(query.clone(), k, params.ef, prefilter_bitset, storage)?
};
// if the queue is full, we just don't push it back, so ignore the error here
let _ = self.inner.visited_generator_queue.push(prefilter_generator);

let row_ids = UInt64Array::from_iter_values(results.iter().map(|x| storage.row_id(x.id)));
let distances = Arc::new(Float32Array::from_iter_values(
Expand Down

0 comments on commit a8603a1

Please sign in to comment.