From e89c7adff8876de9d41c4f2b9f1ee55200bfddb4 Mon Sep 17 00:00:00 2001
From: BubbleCal
Date: Mon, 8 Jul 2024 18:49:19 +0800
Subject: [PATCH 1/5] perf: improve inverted index performance

- add a stopword filter to skip words that occur in almost every
  document, which otherwise cause very bad performance
- store the tokens, the inverted list and the docs in `HashMap`s
- store the docs in a `LargeStringArray`: a single doc can be large,
  so the total length could exceed `i32::MAX`
- more efficient updating of the inverted list

Signed-off-by: BubbleCal
---
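(Not part of the patch: the sketch below restates the indexing layout this
series moves to, with simplified std-only types and illustrative names —
`Sketch` and `add_doc` are hypothetical stand-ins for the real
TokenSet/InvertedList in the diff below. The point is the update path:
count tokens per document first, then append a single (row_id, frequency)
posting per distinct token, instead of a binary search plus Vec::insert
per token occurrence.)

    use std::collections::HashMap;

    #[derive(Default)]
    struct Sketch {
        // token text -> (token id, corpus-wide frequency)
        tokens: HashMap<String, (u32, u64)>,
        // token id -> postings of (row id, frequency in that row)
        inverted: HashMap<u32, Vec<(u64, u64)>>,
        next_id: u32,
    }

    impl Sketch {
        fn add_doc(&mut self, row_id: u64, doc: &str) {
            let mut per_doc: HashMap<u32, u64> = HashMap::new();
            for word in doc.split_whitespace() {
                let next_id = self.next_id;
                // O(1) amortized lookup-or-assign of the token id
                let (id, freq) = self.tokens.entry(word.to_owned()).or_insert((next_id, 0));
                if *id == next_id {
                    self.next_id += 1;
                }
                *freq += 1;
                *per_doc.entry(*id).or_insert(0) += 1;
            }
            // one posting appended per distinct token of this document
            for (token_id, freq) in per_doc {
                self.inverted.entry(token_id).or_default().push((row_id, freq));
            }
        }
    }

    fn main() {
        let mut index = Sketch::default();
        index.add_doc(0, "lance database search");
        index.add_doc(1, "lance database");
        let lance_id = index.tokens["lance"].0;
        assert_eq!(index.inverted[&lance_id], vec![(0, 1), (1, 1)]);
    }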
 rust/lance-index/Cargo.toml             |   5 +
 rust/lance-index/benches/inverted.rs    |  96 ++++++++
 rust/lance-index/src/scalar/inverted.rs | 298 +++++++++++++-----------
 3 files changed, 263 insertions(+), 136 deletions(-)
 create mode 100644 rust/lance-index/benches/inverted.rs

diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml
index 348fca7478..d9e2ee71f6 100644
--- a/rust/lance-index/Cargo.toml
+++ b/rust/lance-index/Cargo.toml
@@ -63,6 +63,7 @@ lance-datagen.workspace = true
 lance-testing.workspace = true
 tempfile.workspace = true
 datafusion-sql.workspace = true
+random_word = { version = "0.4.3", features = ["en"] }
 
 [build-dependencies]
 prost-build.workspace = true
@@ -90,3 +91,7 @@
 [[bench]]
 name = "sq"
 harness = false
+
+[[bench]]
+name = "inverted"
+harness = false
diff --git a/rust/lance-index/benches/inverted.rs b/rust/lance-index/benches/inverted.rs
new file mode 100644
index 0000000000..f87862ce31
--- /dev/null
+++ b/rust/lance-index/benches/inverted.rs
@@ -0,0 +1,96 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! Benchmark of the inverted index.
+//!
+//!
+
+use std::{sync::Arc, time::Duration};
+
+use arrow_array::{LargeStringArray, RecordBatch, UInt64Array};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use futures::stream;
+use itertools::Itertools;
+use lance_core::ROW_ID;
+use lance_index::scalar::inverted::InvertedIndex;
+use lance_index::scalar::lance_format::LanceIndexStore;
+use lance_index::scalar::{ScalarIndex, ScalarQuery};
+use lance_io::object_store::ObjectStore;
+use object_store::path::Path;
+#[cfg(target_os = "linux")]
+use pprof::criterion::{Output, PProfProfiler};
+
+fn bench_inverted(c: &mut Criterion) {
+    const TOTAL: usize = 30_000_000;
+
+    let rt = tokio::runtime::Runtime::new().unwrap();
+
+    let tempdir = tempfile::tempdir().unwrap();
+    let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap();
+    let store = Arc::new(LanceIndexStore::new(ObjectStore::local(), index_dir, None));
+
+    let invert_index = InvertedIndex::default();
+    // use the English word list as the token set
+    let tokens = random_word::all(random_word::Lang::En);
+    let row_id_col = Arc::new(UInt64Array::from(
+        (0..TOTAL).map(|i| i as u64).collect_vec(),
+    ));
+    let docs = (0..TOTAL)
+        .map(|_| {
+            let num_words = rand::random::<usize>() % 100 + 1;
+            let doc = (0..num_words)
+                .map(|_| tokens[rand::random::<usize>() % tokens.len()])
+                .collect::<Vec<_>>();
+            doc.join(" ")
+        })
+        .collect_vec();
+    let doc_col = Arc::new(LargeStringArray::from(docs));
+    let batch = RecordBatch::try_new(
+        arrow_schema::Schema::new(vec![
+            arrow_schema::Field::new("doc", arrow_schema::DataType::LargeUtf8, false),
+            arrow_schema::Field::new(ROW_ID, arrow_schema::DataType::UInt64, false),
+        ])
+        .into(),
+        vec![doc_col.clone(), row_id_col.clone()],
+    )
+    .unwrap();
+    let stream = RecordBatchStreamAdapter::new(batch.schema(), stream::iter(vec![Ok(batch)]));
+    let stream = Box::pin(stream);
+
+    rt.block_on(async { invert_index.update(stream, store.as_ref()).await.unwrap() });
+    let invert_index = rt.block_on(InvertedIndex::load(store)).unwrap();
+
+    c.bench_function(format!("invert({TOTAL})").as_str(), |b| {
+        b.to_async(&rt).iter(|| async {
+            black_box(
+                invert_index
+                    .search(&ScalarQuery::FullTextSearch(vec![tokens
+                        [rand::random::<usize>() % tokens.len()]
+                        .to_owned()]))
+                    .await
+                    .unwrap(),
+            );
+        })
+    });
+}
+
+#[cfg(target_os = "linux")]
+criterion_group!(
+    name=benches;
+    config = Criterion::default()
+        .measurement_time(Duration::from_secs(10))
+        .sample_size(10)
+        .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
+    targets = bench_inverted);
+
+// Non-linux version does not support pprof.
+#[cfg(not(target_os = "linux"))]
+criterion_group!(
+    name=benches;
+    config = Criterion::default()
+        .measurement_time(Duration::from_secs(10))
+        .sample_size(10);
+    targets = bench_inverted);
+
+criterion_main!(benches);
diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs
index 53ce44f77a..c57ac06376 100644
--- a/rust/lance-index/src/scalar/inverted.rs
+++ b/rust/lance-index/src/scalar/inverted.rs
@@ -4,9 +4,9 @@
 use std::collections::HashMap;
 use std::sync::Arc;
 
-use arrow::array::{AsArray, ListBuilder, UInt64Builder};
-use arrow::datatypes;
-use arrow_array::{ArrayRef, RecordBatch, StringArray, UInt32Array, UInt64Array};
+use arrow::array::{AsArray, LargeStringBuilder, ListBuilder, UInt32Builder, UInt64Builder};
+use arrow::datatypes::{self, UInt64Type};
+use arrow_array::{ArrayRef, RecordBatch, UInt32Array, UInt64Array};
 use arrow_schema::{DataType, Field};
 use async_trait::async_trait;
 use datafusion::execution::SendableRecordBatchStream;
@@ -16,6 +16,7 @@ use itertools::Itertools;
 use lance_core::{Error, Result, ROW_ID};
 use roaring::RoaringBitmap;
 use snafu::{location, Location};
+use tantivy::tokenizer::TokenFilter;
 
 use crate::vector::graph::OrderedFloat;
 use crate::Index;
@@ -60,19 +61,15 @@ impl InvertedIndex {
         token_ids
             .into_iter()
             .filter_map(|token| self.invert_list.retrieve(token))
-            .for_each(|(row_ids, freq)| {
+            .for_each(|row_freq| {
                 // TODO: this can be optimized by parallelizing the calculation
-                row_ids
-                    .iter()
-                    .zip(freq.iter())
-                    .for_each(|(&row_id, &freq)| {
-                        let freq = freq as f32;
-                        let bm25 = bm25.entry(row_id).or_insert(0.0);
-                        *bm25 += self.idf(row_ids.len()) * freq * (K1 + 1.0)
-                            / (freq
-                                + K1 * (1.0 - B
-                                    + B * self.docs.num_tokens[row_id as usize] as f32 / avgdl));
-                    });
+                row_freq.iter().for_each(|(row_id, freq)| {
+                    let row_id = *row_id;
+                    let freq = *freq as f32;
+                    let bm25 = bm25.entry(row_id).or_insert(0.0);
+                    *bm25 += self.idf(row_freq.len()) * freq * (K1 + 1.0)
+                        / (freq + K1 * (1.0 - B + B * self.docs.num_tokens(row_id) as f32 / avgdl));
+                });
             });
 
         bm25.into_iter()
@@ -82,7 +79,7 @@
     #[inline]
     fn idf(&self, nq: usize) -> f32 {
-        let num_docs = self.docs.row_ids.len() as f32;
+        let num_docs = self.docs.len() as f32;
         ((num_docs - nq as f32 + 0.5) / (nq as f32 + 0.5) + 1.0).ln()
     }
 }
@@ -107,7 +104,7 @@
     fn statistics(&self) -> Result<serde_json::Value> {
         Ok(serde_json::json!({
             "num_tokens": self.tokens.tokens.len(),
-            "num_docs": self.docs.row_ids.len(),
+            "num_docs": self.docs.token_count.len(),
         }))
     }
@@ -179,25 +176,32 @@
         let mut token_set = self.tokens.clone();
         let mut invert_list = self.invert_list.clone();
         let mut docs = self.docs.clone();
+        let stopword_filter =
+            tantivy::tokenizer::StopWordFilter::new(tantivy::tokenizer::Language::English).unwrap();
         let mut tokenizer = tantivy::tokenizer::TextAnalyzer::builder(
-            tantivy::tokenizer::SimpleTokenizer::default(),
+            stopword_filter.transform(tantivy::tokenizer::SimpleTokenizer::default()),
         )
         .build();
         let mut stream = new_data.peekable();
         while let Some(batch) = stream.try_next().await? {
-            let doc_col = batch.column(0).as_string::<i32>();
+            let doc_col = batch.column(0).as_string::<i64>();
             let row_id_col = batch[ROW_ID].as_primitive::<UInt64Type>();
             for (doc, row_id) in doc_col.iter().zip(row_id_col.iter()) {
                 let doc = doc.unwrap();
                 let row_id = row_id.unwrap();
 
                 let mut token_stream = tokenizer.token_stream(doc);
+                let mut row_token_cnt = HashMap::new();
                 let mut token_cnt = 0;
                 while let Some(token) = token_stream.next() {
-                    let token_id = token_set.add(token.text.clone());
-                    invert_list.add(token_id, row_id);
+                    let token_id = token_set.add(token.text.to_owned());
+                    row_token_cnt
+                        .entry(token_id)
+                        .and_modify(|cnt| *cnt += 1)
+                        .or_insert(1);
                     token_cnt += 1;
                 }
+                invert_list.add(row_token_cnt, row_id);
                 docs.add(row_id, token_cnt);
             }
         }
@@ -233,19 +237,29 @@
 // it also records the frequency of each token
 #[derive(Debug, Clone, Default, DeepSizeOf)]
 struct TokenSet {
-    tokens: Vec<String>,
-    ids: Vec<u32>,
-    frequencies: Vec<u64>,
+    // token -> (token_id, frequency)
+    tokens: HashMap<String, (u32, u64)>,
+    next_id: u32,
 }
 
 impl TokenSet {
     fn to_batch(&self) -> Result<RecordBatch> {
-        let token_col = StringArray::from(self.tokens.clone());
-        let token_id_col = UInt32Array::from(self.ids.clone());
-        let frequency_col = UInt64Array::from(self.frequencies.clone());
+        let mut tokens_builder = LargeStringBuilder::with_capacity(self.tokens.len(), 32);
+        let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len());
+        let mut frequency_builder = UInt64Builder::with_capacity(self.tokens.len());
+        self.tokens
+            .iter()
+            .for_each(|(token, (token_id, frequency))| {
+                tokens_builder.append_value(token);
+                token_id_builder.append_value(*token_id);
+                frequency_builder.append_value(*frequency);
+            });
+        let token_col = tokens_builder.finish();
+        let token_id_col = token_id_builder.finish();
+        let frequency_col = frequency_builder.finish();
 
         let schema = arrow_schema::Schema::new(vec![
-            arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false),
+            arrow_schema::Field::new(TOKEN_COL, DataType::LargeUtf8, false),
             arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
             arrow_schema::Field::new(FREQUENCY_COL, DataType::UInt64, false),
         ]);
@@ -262,51 +276,53 @@
     }
 
     async fn load(reader: Arc<dyn IndexReader>) -> Result<Self> {
-        let mut tokens = Vec::new();
-        let mut ids = Vec::new();
-        let mut frequencies = Vec::new();
+        let mut tokens = HashMap::new();
+        let mut next_id = 0;
         for i in 0..reader.num_batches().await {
             let batch = reader.read_record_batch(i).await?;
-            let token_col = batch[TOKEN_COL].as_string::<i32>();
+            let token_col = batch[TOKEN_COL].as_string::<i64>();
             let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
             let frequency_col = batch[FREQUENCY_COL].as_primitive::<datatypes::UInt64Type>();
 
-            tokens.extend(token_col.iter().map(|v| v.unwrap().to_owned()));
-            ids.extend(token_id_col.iter().map(|v| v.unwrap()));
-            frequencies.extend(frequency_col.iter().map(|v| v.unwrap()));
+            for ((token, token_id), frequency) in token_col
+                .iter()
+                .zip(token_id_col.iter())
+                .zip(frequency_col.iter())
+            {
+                let token = token.unwrap();
+                let token_id = token_id.unwrap();
+                let frequency = frequency.unwrap();
+                tokens.insert(token.to_owned(), (token_id, frequency));
+                next_id = next_id.max(token_id + 1);
+            }
         }
 
-        Ok(Self {
-            tokens,
-            ids,
-            frequencies,
-        })
+        Ok(Self { tokens, next_id })
     }
 
     fn add(&mut self, token: String) -> u32 {
-        let token_id = match self.get(&token) {
-            Some(token_id) => token_id,
-            None => self.next_id(),
-        };
+        let next_id = self.next_id();
+        let token_id = self
+            .tokens
+            .entry(token)
+            .and_modify(|(_, freq)| *freq += 1)
+            .or_insert((next_id, 1))
+            .0;
 
         // add token if it doesn't exist
-        if token_id == self.next_id() {
-            self.tokens.push(token);
-            self.ids.push(token_id);
-            self.frequencies.push(0);
+        if token_id == next_id {
+            self.next_id += 1;
         }
-        self.frequencies[token_id as usize] += 1;
 
         token_id
     }
 
-    fn get(&self, token: &String) -> Option<u32> {
-        let pos = self.tokens.binary_search(token).ok()?;
-        Some(self.ids[pos])
+    fn get(&self, token: &str) -> Option<u32> {
+        self.tokens.get(token).map(|(token_id, _)| *token_id)
     }
 
     fn next_id(&self) -> u32 {
-        self.ids.last().map(|id| id + 1).unwrap_or(0)
+        self.next_id
     }
 }
@@ -314,34 +330,35 @@
 // it's used to retrieve the documents that contain a token
 #[derive(Debug, Clone, Default, DeepSizeOf)]
 struct InvertedList {
-    tokens: Vec<u32>,
-    row_ids_list: Vec<Vec<u64>>,
-    frequencies_list: Vec<Vec<u64>>,
+    inverted_list: HashMap<u32, Vec<(u64, u64)>>,
+    // tokens: Vec<u32>,
+    // row_ids_list: Vec<Vec<u64>>,
+    // frequencies_list: Vec<Vec<u64>>,
 }
 
 impl InvertedList {
     fn to_batch(&self) -> Result<RecordBatch> {
-        let token_id_col = UInt32Array::from(self.tokens.clone());
-        let mut row_ids_col =
-            ListBuilder::with_capacity(UInt64Builder::new(), self.row_ids_list.len());
-        let mut frequencies_col =
-            ListBuilder::with_capacity(UInt64Builder::new(), self.frequencies_list.len());
-
-        for row_ids in &self.row_ids_list {
-            let builder = row_ids_col.values();
-            for row_id in row_ids {
-                builder.append_value(*row_id);
+        let mut token_id_builder = UInt32Builder::with_capacity(self.inverted_list.len());
+        let mut row_ids_list_builder =
+            ListBuilder::with_capacity(UInt64Builder::new(), self.inverted_list.len());
+        let mut frequencies_list_builder =
+            ListBuilder::with_capacity(UInt64Builder::new(), self.inverted_list.len());
+
+        for (token_id, list) in &self.inverted_list {
+            token_id_builder.append_value(*token_id);
+            let row_ids_builder = row_ids_list_builder.values();
+            let frequencies_builder = frequencies_list_builder.values();
+            for (row_id, frequency) in list {
+                row_ids_builder.append_value(*row_id);
+                frequencies_builder.append_value(*frequency);
             }
-            row_ids_col.append(true);
+            row_ids_list_builder.append(true);
+            frequencies_list_builder.append(true);
         }
 
-        for frequencies in &self.frequencies_list {
-            let builder = frequencies_col.values();
-            for frequency in frequencies {
-                builder.append_value(*frequency);
-            }
-            frequencies_col.append(true);
-        }
+        let token_id_col = token_id_builder.finish();
+        let row_ids_col = row_ids_list_builder.finish();
+        let frequencies_col = frequencies_list_builder.finish();
 
         let schema = arrow_schema::Schema::new(vec![
             arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
@@ -361,71 +378,56 @@
             Arc::new(schema),
             vec![
                 Arc::new(token_id_col) as ArrayRef,
-                Arc::new(row_ids_col.finish()) as ArrayRef,
-                Arc::new(frequencies_col.finish()) as ArrayRef,
+                Arc::new(row_ids_col) as ArrayRef,
+                Arc::new(frequencies_col) as ArrayRef,
             ],
         )?;
         Ok(batch)
     }
 
     async fn load(reader: Arc<dyn IndexReader>) -> Result<Self> {
-        let mut tokens = Vec::new();
-        let mut row_ids_list = Vec::new();
-        let mut frequencies_list = Vec::new();
+        let mut inverted_list = HashMap::new();
         for i in 0..reader.num_batches().await {
             let batch = reader.read_record_batch(i).await?;
             let token_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
             let row_ids_col = batch[ROW_ID].as_list::<i32>();
             let frequencies_col = batch[FREQUENCY_COL].as_list::<i32>();
 
-            tokens.extend(token_col.iter().map(|v| v.unwrap()));
-            for value in row_ids_col.iter() {
-                let value = value.unwrap();
-                let row_ids = value
-                    .as_primitive::<datatypes::UInt64Type>()
-                    .values()
-                    .iter()
-                    .cloned()
-                    .collect_vec();
-                row_ids_list.push(row_ids);
-            }
-            for value in frequencies_col.iter() {
-                let value = value.unwrap();
-                let frequencies = value
-                    .as_primitive::<datatypes::UInt64Type>()
-                    .values()
-                    .iter()
-                    .cloned()
-                    .collect_vec();
-                frequencies_list.push(frequencies);
+            for ((token_id, row_ids), frequencies) in token_col
+                .iter()
+                .zip(row_ids_col.iter())
+                .zip(frequencies_col.iter())
+            {
+                let token_id = token_id.unwrap();
+                let row_ids = row_ids.unwrap();
+                let frequencies = frequencies.unwrap();
+                let row_ids = row_ids.as_primitive::<UInt64Type>().values();
+                let frequencies = frequencies.as_primitive::<UInt64Type>().values();
+                let list = row_ids
+                    .iter()
+                    .cloned()
+                    .zip(frequencies.iter().cloned())
+                    .collect_vec();
+                inverted_list.insert(token_id, list);
             }
         }
 
-        Ok(Self {
-            tokens,
-            row_ids_list,
-            frequencies_list,
-        })
+        Ok(Self { inverted_list })
     }
 
-    fn add(&mut self, token_id: u32, row_id: u64) {
-        let pos = match self.tokens.binary_search(&token_id) {
-            Ok(pos) => pos,
-            Err(pos) => {
-                self.tokens.insert(pos, token_id);
-                self.row_ids_list.insert(pos, Vec::new());
-                self.frequencies_list.insert(pos, Vec::new());
-                pos
-            }
-        };
-
-        self.row_ids_list[pos].push(row_id);
-        self.frequencies_list[pos].push(1);
+    // for efficiency, we don't check if the row_id exists
+    // we assume that the row_id is unique and doesn't exist in the list
+    fn add(&mut self, token_cnt: HashMap<u32, u64>, row_id: u64) {
+        for (token_id, freq) in token_cnt {
+            let list = self.inverted_list.entry(token_id).or_default();
+            list.push((row_id, freq));
+        }
     }
 
-    fn retrieve(&self, token_id: u32) -> Option<(&[u64], &[u64])> {
-        let pos = self.tokens.binary_search(&token_id).ok()?;
-        Some((&self.row_ids_list[pos], &self.frequencies_list[pos]))
+    fn retrieve(&self, token_id: u32) -> Option<&[(u64, u64)]> {
+        self.inverted_list
+            .get(&token_id)
+            .map(|list| list.as_slice())
     }
 }
@@ -433,19 +435,29 @@
 // It's used to sort the documents by the bm25 score
 #[derive(Debug, Clone, Default, DeepSizeOf)]
 struct DocSet {
-    row_ids: Vec<u64>,
-    num_tokens: Vec<u32>,
+    // row id -> num tokens
+    token_count: HashMap<u64, u32>,
+    // row_ids: Vec<u64>,
+    // num_tokens: Vec<u32>,
     total_tokens: u64,
 }
 
 impl DocSet {
+    fn len(&self) -> usize {
+        self.token_count.len()
+    }
+
+    fn is_empty(&self) -> bool {
+        self.token_count.is_empty()
+    }
+
     fn average_length(&self) -> f32 {
-        self.total_tokens as f32 / self.row_ids.len() as f32
+        self.total_tokens as f32 / self.token_count.len() as f32
    }
 
     fn to_batch(&self) -> Result<RecordBatch> {
-        let row_id_col = UInt64Array::from(self.row_ids.clone());
-        let num_tokens_col = UInt32Array::from(self.num_tokens.clone());
+        let row_id_col = UInt64Array::from_iter_values(self.token_count.keys().cloned());
+        let num_tokens_col = UInt32Array::from_iter_values(self.token_count.values().cloned());
 
         let schema = arrow_schema::Schema::new(vec![
             arrow_schema::Field::new(ROW_ID, DataType::UInt64, false),
@@ -463,29 +475,34 @@
     }
 
     async fn load(reader: Arc<dyn IndexReader>) -> Result<Self> {
-        let mut row_ids = Vec::new();
-        let mut num_tokens = Vec::new();
+        let mut token_count = HashMap::new();
         let mut total_tokens = 0;
         for i in 0..reader.num_batches().await {
             let batch = reader.read_record_batch(i).await?;
             let row_id_col = batch[ROW_ID].as_primitive::<UInt64Type>();
             let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
 
-            row_ids.extend(row_id_col.iter().map(|v| v.unwrap()));
-            num_tokens.extend(num_tokens_col.iter().map(|v| v.unwrap()));
-            total_tokens += num_tokens.iter().map(|v| *v as u64).sum::<u64>();
+            for (row_id, num_tokens) in row_id_col.iter().zip(num_tokens_col.iter()) {
+                let row_id = row_id.unwrap();
+                let num_tokens = num_tokens.unwrap();
+                token_count.insert(row_id, num_tokens);
+                total_tokens += num_tokens as u64;
+            }
         }
 
         Ok(Self {
-            row_ids,
-            num_tokens,
+            token_count,
             total_tokens,
         })
     }
 
+    fn num_tokens(&self, row_id: u64) -> u32 {
+        self.token_count.get(&row_id).cloned().unwrap_or_default()
+    }
+
     fn add(&mut self, row_id: u64, num_tokens: u32) {
-        self.row_ids.push(row_id);
-        self.num_tokens.push(num_tokens);
+        self.token_count.insert(row_id, num_tokens);
+        self.total_tokens += num_tokens as u64;
     }
 }
@@ -493,7 +510,7 @@
 mod tests {
     use std::sync::Arc;
 
-    use arrow_array::{ArrayRef, RecordBatch, StringArray, UInt64Array};
+    use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, UInt64Array};
     use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
     use futures::stream;
     use lance_io::object_store::ObjectStore;
@@ -510,10 +527,15 @@ mod tests {
         let invert_index = super::InvertedIndex::default();
         let row_id_col = UInt64Array::from(vec![0, 1, 2, 3]);
-        let doc_col = StringArray::from(vec!["a b c", "a b", "a c", "b c"]);
+        let doc_col = LargeStringArray::from(vec![
+            "lance database search",
+            "lance database",
+            "lance search",
+            "database search",
+        ]);
         let batch = RecordBatch::try_new(
             arrow_schema::Schema::new(vec![
-                arrow_schema::Field::new("doc", arrow_schema::DataType::Utf8, false),
+                arrow_schema::Field::new("doc", arrow_schema::DataType::LargeUtf8, false),
                 arrow_schema::Field::new(super::ROW_ID, arrow_schema::DataType::UInt64, false),
             ])
             .into(),
@@ -533,7 +555,9 @@
         let invert_index = super::InvertedIndex::load(Arc::new(store)).await.unwrap();
 
         let row_ids = invert_index
-            .search(&super::ScalarQuery::FullTextSearch(vec!["a".to_string()]))
+            .search(&super::ScalarQuery::FullTextSearch(vec![
+                "lance".to_string()
+            ]))
             .await
             .unwrap();
         assert_eq!(row_ids.len(), 3);
         assert!(row_ids.values().contains(&0));
         assert!(row_ids.values().contains(&1));
         assert!(row_ids.values().contains(&2));
 
         let row_ids = invert_index
-            .search(&super::ScalarQuery::FullTextSearch(vec!["b".to_string()]))
+            .search(&super::ScalarQuery::FullTextSearch(vec![
+                "database".to_string()
+            ]))
             .await
             .unwrap();
         assert_eq!(row_ids.len(), 3);

From 15b40b0421cacf2f70ec16585df9a1f2b02ecef0 Mon Sep 17 00:00:00 2001
From: BubbleCal
Date: Mon, 8 Jul 2024 19:53:48 +0800
Subject: [PATCH 2/5] fmt

Signed-off-by: BubbleCal
---
 rust/lance-index/src/scalar/inverted.rs | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs
index c57ac06376..7553afc8b9 100644
--- a/rust/lance-index/src/scalar/inverted.rs
+++ b/rust/lance-index/src/scalar/inverted.rs
@@ -447,10 +447,6 @@ impl DocSet {
         self.token_count.len()
     }
 
-    fn is_empty(&self) -> bool {
-        self.token_count.is_empty()
-    }
-
     fn average_length(&self) -> f32 {
         self.total_tokens as f32 / self.token_count.len() as f32
     }
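(Not part of the series: a minimal standalone example of the stopword
filtering that patch 1 wires into the tokenizer. It reuses the exact
tantivy calls from that patch; the printed token list is indicative and
depends on tantivy's English stopword list.)

    use tantivy::tokenizer::{
        Language, SimpleTokenizer, StopWordFilter, TextAnalyzer, TokenFilter, TokenStream,
    };

    fn main() {
        // same analyzer construction as in InvertedIndex::update
        let stopword_filter = StopWordFilter::new(Language::English).unwrap();
        let mut tokenizer =
            TextAnalyzer::builder(stopword_filter.transform(SimpleTokenizer::default())).build();

        let mut stream = tokenizer.token_stream("the quick fox is in the lance index");
        let mut kept = Vec::new();
        while let Some(token) = stream.next() {
            kept.push(token.text.clone());
        }
        // stopwords such as "the"/"is"/"in" never reach the token set,
        // so their huge posting lists are never built or scored
        println!("{:?}", kept); // expected: ["quick", "fox", "lance", "index"]
    }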
From 8686d58ff734230ea35b55ea208d554071c78c14 Mon Sep 17 00:00:00 2001
From: BubbleCal
Date: Mon, 8 Jul 2024 20:11:02 +0800
Subject: [PATCH 3/5] smaller dataset

Signed-off-by: BubbleCal
---
 rust/lance-index/benches/inverted.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rust/lance-index/benches/inverted.rs b/rust/lance-index/benches/inverted.rs
index f87862ce31..0f20624b33 100644
--- a/rust/lance-index/benches/inverted.rs
+++ b/rust/lance-index/benches/inverted.rs
@@ -22,7 +22,7 @@ use object_store::path::Path;
 use pprof::criterion::{Output, PProfProfiler};
 
 fn bench_inverted(c: &mut Criterion) {
-    const TOTAL: usize = 30_000_000;
+    const TOTAL: usize = 1_000_000;
 
     let rt = tokio::runtime::Runtime::new().unwrap();

From ff61924be6d494846035f4b42e16e19afb050b12 Mon Sep 17 00:00:00 2001
From: BubbleCal
Date: Tue, 9 Jul 2024 11:33:31 +0800
Subject: [PATCH 4/5] fix comments

Signed-off-by: BubbleCal
---
 rust/lance-index/src/scalar/inverted.rs | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs
index 7553afc8b9..05854c1afc 100644
--- a/rust/lance-index/src/scalar/inverted.rs
+++ b/rust/lance-index/src/scalar/inverted.rs
@@ -51,6 +51,7 @@ impl InvertedIndex {
 
     // search the documents that contain the query
     // return the row ids of the documents sorted by bm25 score
+    // ref: https://en.wikipedia.org/wiki/Okapi_BM25
     fn bm25_search(&self, token_ids: Vec<u32>) -> Vec<(u64, f32)> {
         const K1: f32 = 1.2;
         const B: f32 = 0.75;
@@ -284,14 +285,12 @@ impl TokenSet {
             let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
             let frequency_col = batch[FREQUENCY_COL].as_primitive::<datatypes::UInt64Type>();
 
-            for ((token, token_id), frequency) in token_col
+            for ((token, &token_id), &frequency) in token_col
                 .iter()
-                .zip(token_id_col.iter())
-                .zip(frequency_col.iter())
+                .zip(token_id_col.values().iter())
+                .zip(frequency_col.values().iter())
             {
                 let token = token.unwrap();
-                let token_id = token_id.unwrap();
-                let frequency = frequency.unwrap();
                 tokens.insert(token.to_owned(), (token_id, frequency));
                 next_id = next_id.max(token_id + 1);
             }
@@ -393,12 +392,12 @@ impl InvertedList {
             let row_ids_col = batch[ROW_ID].as_list::<i32>();
             let frequencies_col = batch[FREQUENCY_COL].as_list::<i32>();
 
-            for ((token_id, row_ids), frequencies) in token_col
+            for ((&token_id, row_ids), frequencies) in token_col
+                .values()
                 .iter()
                 .zip(row_ids_col.iter())
                 .zip(frequencies_col.iter())
             {
-                let token_id = token_id.unwrap();
                 let row_ids = row_ids.unwrap();
                 let frequencies = frequencies.unwrap();
                 let row_ids = row_ids.as_primitive::<UInt64Type>().values();
@@ -478,9 +477,11 @@ impl DocSet {
             let row_id_col = batch[ROW_ID].as_primitive::<UInt64Type>();
             let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
 
-            for (row_id, num_tokens) in row_id_col.iter().zip(num_tokens_col.iter()) {
-                let row_id = row_id.unwrap();
-                let num_tokens = num_tokens.unwrap();
+            for (&row_id, &num_tokens) in row_id_col
+                .values()
+                .iter()
+                .zip(num_tokens_col.values().iter())
+            {
                 token_count.insert(row_id, num_tokens);
                 total_tokens += num_tokens as u64;
             }
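(Not part of the series: a worked restatement of the Okapi BM25 scoring
that bm25_search implements, for readers following the reference added in
patch 4. K1 = 1.2, B = 0.75 and the idf form are copied from the code;
the corpus numbers come from the unit test in patch 1.)

    score(D, q) = idf(nq) * f * (K1 + 1) / (f + K1 * (1 - B + B * |D| / avgdl))
    idf(nq)     = ln((N - nq + 0.5) / (nq + 0.5) + 1)

    // Standalone sketch: one token's BM25 contribution for one document.
    fn idf(num_docs: usize, nq: usize) -> f32 {
        // nq = number of documents containing the token
        ((num_docs as f32 - nq as f32 + 0.5) / (nq as f32 + 0.5) + 1.0).ln()
    }

    fn bm25(freq: f32, doc_tokens: f32, avgdl: f32, num_docs: usize, nq: usize) -> f32 {
        const K1: f32 = 1.2;
        const B: f32 = 0.75;
        idf(num_docs, nq) * freq * (K1 + 1.0) / (freq + K1 * (1.0 - B + B * doc_tokens / avgdl))
    }

    fn main() {
        // unit-test corpus: 4 docs, "lance" in 3 of them once each;
        // doc 0 has 3 tokens, average doc length is 9 / 4 = 2.25
        let score = bm25(1.0, 3.0, 2.25, 4, 3);
        println!("bm25(doc 0, \"lance\") = {score:.3}"); // ~= 0.314
    }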
From c646e3dbd3416027ca3adf5024d458e6afe3d351 Mon Sep 17 00:00:00 2001
From: BubbleCal
Date: Fri, 12 Jul 2024 14:53:05 +0800
Subject: [PATCH 5/5] fix

Signed-off-by: BubbleCal
---
 rust/lance-index/benches/inverted.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/rust/lance-index/benches/inverted.rs b/rust/lance-index/benches/inverted.rs
index 0f20624b33..ef73caa1c3 100644
--- a/rust/lance-index/benches/inverted.rs
+++ b/rust/lance-index/benches/inverted.rs
@@ -15,7 +15,7 @@ use itertools::Itertools;
 use lance_core::ROW_ID;
 use lance_index::scalar::inverted::InvertedIndex;
 use lance_index::scalar::lance_format::LanceIndexStore;
-use lance_index::scalar::{ScalarIndex, ScalarQuery};
+use lance_index::scalar::{SargableQuery, ScalarIndex};
 use lance_io::object_store::ObjectStore;
 use object_store::path::Path;
 #[cfg(target_os = "linux")]
@@ -65,7 +65,7 @@ fn bench_inverted(c: &mut Criterion) {
         b.to_async(&rt).iter(|| async {
             black_box(
                 invert_index
-                    .search(&ScalarQuery::FullTextSearch(vec![tokens
+                    .search(&SargableQuery::FullTextSearch(vec![tokens
                        [rand::random::<usize>() % tokens.len()]
                        .to_owned()]))
                    .await