Skip to content

Commit

Permalink
feat: integrate bitmap index into top-level lance APIs (#2575)
Browse files Browse the repository at this point in the history
This also fixes a minor synchronization perf issue in the prefilter
search path
  • Loading branch information
westonpace committed Jul 9, 2024
1 parent f94565f commit 49de38e
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 36 deletions.
142 changes: 134 additions & 8 deletions python/python/benchmarks/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import NamedTuple, Union

import lance
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pytest
Expand Down Expand Up @@ -34,7 +35,14 @@ def create_table(num_rows, offset) -> pa.Table:
values = pc.random(num_rows * N_DIMS).cast(pa.float32())
vectors = pa.FixedSizeListArray.from_arrays(values, N_DIMS)
filterable = pa.array(range(offset, offset + num_rows))
return pa.table({"vector": vectors, "filterable": filterable})
categories = pa.array(np.random.randint(0, 100, num_rows))
return pa.table(
{
"vector": vectors,
"filterable": filterable,
"category": categories,
}
)


def create_base_dataset(data_dir: Path) -> lance.LanceDataset:
Expand Down Expand Up @@ -66,8 +74,9 @@ def create_base_dataset(data_dir: Path) -> lance.LanceDataset:
)

dataset.create_scalar_index("filterable", "BTREE")
dataset.create_scalar_index("category", "BITMAP")

return dataset
return lance.dataset(tmp_path, index_cache_size=64 * 1024)


def create_delete_dataset(data_dir):
Expand All @@ -82,7 +91,7 @@ def create_delete_dataset(data_dir):
dataset = lance.dataset(tmp_path)
dataset.delete("filterable % 2 != 0")

return dataset
return lance.dataset(tmp_path, index_cache_size=64 * 1024)


def create_new_rows_dataset(data_dir):
Expand All @@ -98,7 +107,7 @@ def create_new_rows_dataset(data_dir):
table = create_table(NEW_ROWS, offset=NUM_ROWS)
dataset = lance.write_dataset(table, tmp_path, mode="append")

return dataset
return lance.dataset(tmp_path, index_cache_size=64 * 1024)


class Datasets(NamedTuple):
Expand Down Expand Up @@ -129,6 +138,8 @@ def test_knn_search(test_dataset, benchmark):
q = pc.random(N_DIMS).cast(pa.float32())
result = benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
nearest=dict(
column="vector",
q=q,
Expand All @@ -141,10 +152,12 @@ def test_knn_search(test_dataset, benchmark):


@pytest.mark.benchmark(group="query_ann")
def test_flat_index_search(test_dataset, benchmark):
def test_ann_no_refine(test_dataset, benchmark):
q = pc.random(N_DIMS).cast(pa.float32())
result = benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
nearest=dict(
column="vector",
q=q,
Expand All @@ -156,10 +169,12 @@ def test_flat_index_search(test_dataset, benchmark):


@pytest.mark.benchmark(group="query_ann")
def test_ivf_pq_index_search(test_dataset, benchmark):
def test_ann_with_refine(test_dataset, benchmark):
q = pc.random(N_DIMS).cast(pa.float32())
result = benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
nearest=dict(
column="vector",
q=q,
Expand All @@ -180,6 +195,8 @@ def test_filtered_search(test_dataset, benchmark, selectivity, prefilter, use_in
threshold = int(round(selectivity * NUM_ROWS))
result = benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
nearest=dict(
column="vector",
q=q,
Expand Down Expand Up @@ -221,11 +238,13 @@ def test_filtered_search(test_dataset, benchmark, selectivity, prefilter, use_in
"greater_than_not_selective",
],
)
def test_scalar_index_prefilter(test_dataset, benchmark, filter: str):
def test_btree_index_prefilter(test_dataset, benchmark, filter: str):
q = pc.random(N_DIMS).cast(pa.float32())
if filter is None:
benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
nearest=dict(
column="vector",
q=q,
Expand All @@ -236,6 +255,8 @@ def test_scalar_index_prefilter(test_dataset, benchmark, filter: str):
else:
benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
nearest=dict(
column="vector",
q=q,
Expand Down Expand Up @@ -275,14 +296,119 @@ def test_scalar_index_prefilter(test_dataset, benchmark, filter: str):
"greater_than_not_selective",
],
)
def test_scalar_index_search(test_dataset, benchmark, filter: str):
def test_btree_index_search(test_dataset, benchmark, filter: str):
if filter is None:
benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
)
else:
benchmark(
test_dataset.to_table,
columns=[],
with_row_id=True,
prefilter=True,
filter=filter,
)


@pytest.mark.benchmark(group="query_ann")
@pytest.mark.parametrize(
    "filter",
    (
        None,
        "category = 0",
        "category != 0",
        "category IN (0)",
        "category NOT IN (0)",
        "category != 0 AND category != 3 AND category != 7",
        "category NOT IN (0, 3, 7)",
        "category < 5",
        "category > 5",
    ),
    ids=[
        "none",
        "equality",
        "not_equality",
        "in_list_one",
        "not_in_list_one",
        "not_equality_and_chain",
        "not_in_list_three",
        "less_than_selective",
        "greater_than_not_selective",
    ],
)
def test_bitmap_index_prefilter(test_dataset, benchmark, filter: str):
    """Benchmark ANN search prefiltered on ``category``.

    The fixture dataset has a BITMAP scalar index on the ``category``
    column, so each non-``None`` filter exercises the bitmap index on
    the prefilter path.  Only row ids are materialized (``columns=[]``)
    so the measurement focuses on the search, not on column I/O.
    """
    q = pc.random(N_DIMS).cast(pa.float32())
    # Arguments common to the filtered and unfiltered variants.
    query_args = dict(
        columns=[],
        with_row_id=True,
        nearest=dict(
            column="vector",
            q=q,
            k=100,
            nprobes=10,
        ),
    )
    if filter is not None:
        # Apply the filter before the vector search (prefilter=True) so
        # the scalar index, rather than post-filtering, is measured.
        query_args["prefilter"] = True
        query_args["filter"] = filter
    benchmark(test_dataset.to_table, **query_args)


@pytest.mark.benchmark(group="query_no_vec")
@pytest.mark.parametrize(
    "filter",
    (
        None,
        "category = 0",
        "category != 0",
        "category IN (0)",
        "category IN (0, 3, 7)",
        "category NOT IN (0)",
        "category != 0 AND category != 3 AND category != 7",
        "category NOT IN (0, 3, 7)",
        "category < 5",
        "category > 5",
    ),
    ids=[
        "none",
        "equality",
        "not_equality",
        "in_list_one",
        "in_list_three",
        "not_in_list_one",
        "not_equality_and_chain",
        "not_in_list_three",
        "less_than_selective",
        "greater_than_not_selective",
    ],
)
def test_bitmap_index_search(test_dataset, benchmark, filter: str):
    """Benchmark a non-vector scan, optionally filtered on ``category``.

    The fixture dataset has a BITMAP scalar index on ``category``; each
    non-``None`` filter exercises that index.  Only row ids are
    materialized (``columns=[]``) to keep column I/O out of the timing.
    """
    # Arguments common to both variants; the filter is added on demand.
    query_args = dict(columns=[], with_row_id=True)
    if filter is not None:
        query_args.update(prefilter=True, filter=filter)
    benchmark(test_dataset.to_table, **query_args)
24 changes: 14 additions & 10 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ def cleanup_old_versions(
def create_scalar_index(
self,
column: str,
index_type: Literal["BTREE"],
index_type: Union[Literal["BTREE"], Literal["BITMAP"]],
name: Optional[str] = None,
*,
replace: bool = True,
Expand Down Expand Up @@ -1159,9 +1159,15 @@ def create_scalar_index(
that use scalar indices will either have a ``ScalarIndexQuery`` relation or a
``MaterializeIndex`` operator.
Currently, the only type of scalar index available is ``BTREE``. This index
combines is inspired by the btree data structure although only the first few
layers of the btree are cached in memory.
There are two types of scalar indices available today. The most common
type is ``BTREE``. This index is inspired by the btree data structure
although only the first few layers of the btree are cached in memory. It will
perform well on columns with a large number of unique values and few rows per
value.
The other index type is ``BITMAP``. This index stores a bitmap for each unique
value in the column. This index is useful for columns with a small number of
unique values and many rows per value.
Note that the ``LANCE_BYPASS_SPILLING`` environment variable can be used to
bypass spilling to disk. Setting this to true can avoid memory exhaustion
Expand All @@ -1175,7 +1181,7 @@ def create_scalar_index(
The column to be indexed. Must be a boolean, integer, float,
or string column.
index_type : str
The type of the index. Only ``"BTREE"`` is supported now.
The type of the index. One of ``"BTREE"`` or ``"BITMAP"``.
name : str, optional
The index name. If not provided, it will be generated from the
column name.
Expand Down Expand Up @@ -1226,12 +1232,10 @@ def create_scalar_index(
)

index_type = index_type.upper()
if index_type != "BTREE":
if index_type not in ["BTREE", "BITMAP"]:
raise NotImplementedError(
(
'Only "BTREE" is supported for ',
f"index_type. Received {index_type}",
)
'Only "BTREE" or "BITMAP" are supported for ',
f"scalar columns. Received {index_type}",
)

self._ds.create_index([column], index_type, name, replace)
Expand Down
8 changes: 7 additions & 1 deletion python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use lance::dataset::{
WriteParams,
};
use lance::dataset::{BatchInfo, BatchUDF, NewColumnTransform, UDFCheckpointStore};
use lance::index::scalar::ScalarIndexType;
use lance::index::{scalar::ScalarIndexParams, vector::VectorIndexParams};
use lance_arrow::as_fixed_size_list_array;
use lance_core::datatypes::Schema;
Expand Down Expand Up @@ -920,7 +921,7 @@ impl Dataset {
) -> PyResult<()> {
let index_type = index_type.to_uppercase();
let idx_type = match index_type.as_str() {
"BTREE" => IndexType::Scalar,
"BTREE" | "BITMAP" => IndexType::Scalar,
"IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector,
_ => {
return Err(PyValueError::new_err(format!(
Expand All @@ -932,6 +933,11 @@ impl Dataset {
// Only VectorParams are supported.
let params: Box<dyn IndexParams> = if index_type == "BTREE" {
Box::<ScalarIndexParams>::default()
} else if index_type == "BITMAP" {
Box::new(ScalarIndexParams {
// Temporary workaround until we add support for auto-detection of scalar index type
force_index_type: Some(ScalarIndexType::Bitmap),
})
} else {
let column_type = match self.ds.schema().field(columns[0]) {
Some(f) => f.data_type().clone(),
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-index/src/scalar/bitmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{Index, IndexType};
use super::btree::OrderableScalarValue;
use super::{btree::BtreeTrainingSource, IndexStore, ScalarIndex, ScalarQuery};

const BITMAP_LOOKUP_NAME: &str = "bitmap_page_lookup.lance";
pub const BITMAP_LOOKUP_NAME: &str = "bitmap_page_lookup.lance";

/// A scalar index that stores a bitmap for each possible value
///
Expand Down
10 changes: 9 additions & 1 deletion rust/lance/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use lance_table::format::Index as IndexMetadata;
use lance_table::format::{Fragment, SelfDescribingFileReader};
use lance_table::io::manifest::read_manifest_indexes;
use roaring::RoaringBitmap;
use scalar::ScalarIndexParams;
use serde_json::json;
use snafu::{location, Location};
use tracing::instrument;
Expand Down Expand Up @@ -201,7 +202,14 @@ impl DatasetIndexExt for Dataset {
let index_id = Uuid::new_v4();
match (index_type, params.index_name()) {
(IndexType::Scalar, LANCE_SCALAR_INDEX) => {
build_scalar_index(self, column, &index_id.to_string()).await?;
let params = params
.as_any()
.downcast_ref::<ScalarIndexParams>()
.ok_or_else(|| Error::Index {
message: "Scalar index type must take a ScalarIndexParams".to_string(),
location: location!(),
})?;
build_scalar_index(self, column, &index_id.to_string(), params).await?;
}
(IndexType::Vector, LANCE_VECTOR_INDEX) => {
// Vector index params.
Expand Down
13 changes: 8 additions & 5 deletions rust/lance/src/index/prefilter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub struct DatasetPreFilter {
pub(super) deleted_ids: Option<Arc<SharedPrerequisite<Arc<RowIdMask>>>>,
pub(super) filtered_ids: Option<Arc<SharedPrerequisite<RowIdMask>>>,
// When the tasks are finished this is the combined filter
pub(super) final_mask: Mutex<OnceCell<RowIdMask>>,
pub(super) final_mask: Mutex<OnceCell<Arc<RowIdMask>>>,
}

impl DatasetPreFilter {
Expand Down Expand Up @@ -233,7 +233,7 @@ impl PreFilter for DatasetPreFilter {
if let Some(deleted_ids) = &self.deleted_ids {
combined = combined & (*deleted_ids.get_ready()).clone();
}
combined
Arc::new(combined)
});

Ok(())
Expand All @@ -251,11 +251,14 @@ impl PreFilter for DatasetPreFilter {
/// This method must be called after `wait_for_ready`
#[instrument(level = "debug", skip_all)]
fn filter_row_ids<'a>(&self, row_ids: Box<dyn Iterator<Item = &'a u64> + 'a>) -> Vec<u64> {
let final_mask = self.final_mask.lock().unwrap();
final_mask
let final_mask = self
.final_mask
.lock()
.unwrap()
.get()
.expect("filter_row_ids called without call to wait_for_ready")
.selected_indices(row_ids)
.clone();
final_mask.selected_indices(row_ids)
}
}

Expand Down
Loading

0 comments on commit 49de38e

Please sign in to comment.