diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs
index 2a9b3c957d..123805ce6b 100644
--- a/rust/lance-index/src/scalar/bitmap.rs
+++ b/rust/lance-index/src/scalar/bitmap.rs
@@ -339,7 +339,7 @@ pub async fn train_bitmap_index(
     data_source: Box<dyn TrainingSource>,
     index_store: &dyn IndexStore,
 ) -> Result<()> {
-    let batches_source = data_source.scan_ordered_chunks(4096).await?;
+    let batches_source = data_source.scan_unordered_chunks(4096).await?;
 
     // mapping from item to list of the row ids where it is present
     let dictionary: HashMap<ScalarValue, Vec<u64>> = HashMap::new();
diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs
index 3257b936e4..e01718344f 100644
--- a/rust/lance-index/src/scalar/btree.rs
+++ b/rust/lance-index/src/scalar/btree.rs
@@ -1064,6 +1064,18 @@ pub trait TrainingSource: Send {
         self: Box<Self>,
         chunk_size: u32,
     ) -> Result<SendableRecordBatchStream>;
+
+    /// Returns a stream of batches
+    ///
+    /// Each batch should have chunk_size rows
+    ///
+    /// The schema for the batch is slightly flexible.
+    /// The first column may have any name or type, these are the values to index
+    /// The second column must be the row ids which must be UInt64Type
+    async fn scan_unordered_chunks(
+        self: Box<Self>,
+        chunk_size: u32,
+    ) -> Result<SendableRecordBatchStream>;
 }
 
 /// Train a btree index from a stream of sorted page-size batches of values and row ids
@@ -1153,6 +1165,14 @@ impl TrainingSource for BTreeUpdater {
         )?;
         Ok(chunk_concat_stream(unchunked, chunk_size as usize))
     }
+
+    async fn scan_unordered_chunks(
+        self: Box<Self>,
+        _chunk_size: u32,
+    ) -> Result<SendableRecordBatchStream> {
+        // BTree indices will never use unordered scans
+        unimplemented!()
+    }
 }
 
 /// A stream that reads the original training data back out of the index
diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs
index e55de8f9cb..2b7a33b5fe 100644
--- a/rust/lance-index/src/scalar/label_list.rs
+++ b/rust/lance-index/src/scalar/label_list.rs
@@ -1,12 +1,13 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
-use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
+use std::{any::Any, collections::HashMap, fmt::Debug, pin::Pin, sync::Arc};
 
 use arrow::array::AsArray;
 use arrow_array::{Array, RecordBatch, UInt64Array};
 use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
 use async_trait::async_trait;
+use datafusion::execution::RecordBatchStream;
 use datafusion::physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
 use datafusion_common::ScalarValue;
 use deepsize::DeepSizeOf;
@@ -271,17 +272,31 @@ impl TrainingSource for UnnestTrainingSource {
         chunk_size: u32,
     ) -> Result<SendableRecordBatchStream> {
         let source = self.source.scan_ordered_chunks(chunk_size).await?;
-        let unnest_schema = unnest_schema(source.schema().as_ref());
-        let unnest_schema_copy = unnest_schema.clone();
-        let source = source.try_filter_map(move |batch| {
-            std::future::ready(Some(unnest_batch(batch, unnest_schema.clone())).transpose())
-        });
-
-        Ok(Box::pin(RecordBatchStreamAdapter::new(
-            unnest_schema_copy.clone(),
-            source,
-        )))
+        unnest_chunks(source)
     }
+
+    async fn scan_unordered_chunks(
+        self: Box<Self>,
+        chunk_size: u32,
+    ) -> Result<SendableRecordBatchStream> {
+        let source = self.source.scan_unordered_chunks(chunk_size).await?;
+        unnest_chunks(source)
+    }
+}
+
+fn unnest_chunks(
+    source: Pin<Box<dyn RecordBatchStream + Send>>,
+) -> Result<SendableRecordBatchStream> {
+    let unnest_schema = unnest_schema(source.schema().as_ref());
+    let unnest_schema_copy = unnest_schema.clone();
+    let source = source.try_filter_map(move |batch| {
+        std::future::ready(Some(unnest_batch(batch, unnest_schema.clone())).transpose())
+    });
+
+    Ok(Box::pin(RecordBatchStreamAdapter::new(
+        unnest_schema_copy.clone(),
+        source,
+    )))
 }
 
 /// Trains a new label list index
diff --git a/rust/lance-index/src/scalar/lance_format.rs b/rust/lance-index/src/scalar/lance_format.rs
index e08cffc8bd..ae8a2ac82a 100644
--- a/rust/lance-index/src/scalar/lance_format.rs
+++ b/rust/lance-index/src/scalar/lance_format.rs
@@ -336,6 +336,13 @@ mod tests {
         ) -> Result<SendableRecordBatchStream> {
             Ok(self.data)
         }
+
+        async fn scan_unordered_chunks(
+            self: Box<Self>,
+            _chunk_size: u32,
+        ) -> Result<SendableRecordBatchStream> {
+            Ok(self.data)
+        }
     }
 
     async fn train_index(
diff --git a/rust/lance/benches/scalar_index.rs b/rust/lance/benches/scalar_index.rs
index 63949396fd..9bdd734049 100644
--- a/rust/lance/benches/scalar_index.rs
+++ b/rust/lance/benches/scalar_index.rs
@@ -50,6 +50,13 @@ impl TrainingSource for BenchmarkDataSource {
     ) -> Result<SendableRecordBatchStream> {
         Ok(reader_to_stream(Box::new(Self::test_data())))
     }
+
+    async fn scan_unordered_chunks(
+        self: Box<Self>,
+        _chunk_size: u32,
+    ) -> Result<SendableRecordBatchStream> {
+        Ok(reader_to_stream(Box::new(Self::test_data())))
+    }
 }
 
 impl BenchmarkFixture {
diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs
index da5281a76a..dab9afe1a1 100644
--- a/rust/lance/src/index/scalar.rs
+++ b/rust/lance/src/index/scalar.rs
@@ -38,22 +38,43 @@ impl TrainingSource for TrainingRequest {
     async fn scan_ordered_chunks(
         self: Box<Self>,
         chunk_size: u32,
+    ) -> Result<SendableRecordBatchStream> {
+        self.scan_chunks(chunk_size, true).await
+    }
+
+    async fn scan_unordered_chunks(
+        self: Box<Self>,
+        chunk_size: u32,
+    ) -> Result<SendableRecordBatchStream> {
+        self.scan_chunks(chunk_size, false).await
+    }
+}
+
+impl TrainingRequest {
+    async fn scan_chunks(
+        self: Box<Self>,
+        chunk_size: u32,
+        sort: bool,
     ) -> Result<SendableRecordBatchStream> {
         let mut scan = self.dataset.scan();
+
+        let ordering = match sort {
+            true => Some(vec![ColumnOrdering::asc_nulls_first(self.column.clone())]),
+            false => None,
+        };
+
         let scan = scan
             .with_row_id()
-            .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
-                self.column.clone(),
-            )]))?
+            .order_by(ordering)?
             .project(&[&self.column])?;
-        let ordered_batches = scan
+        let batches = scan
             .try_into_dfstream(LanceExecutionOptions {
                 use_spilling: true,
                 ..Default::default()
             })
             .await?;
-        Ok(chunk_concat_stream(ordered_batches, chunk_size as usize))
+        Ok(chunk_concat_stream(batches, chunk_size as usize))
     }
 }