diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml index a7a533fd8a..5c245ed313 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -1,4 +1,5 @@ name: Build and Run Java JNI Tests + on: push: branches: @@ -8,6 +9,7 @@ on: - java/** - rust/** - .github/workflows/java.yml + env: # This env var is used by Swatinem/rust-cache@v2 for the cache # key, so we set it to make sure it is always consistent. @@ -20,74 +22,75 @@ env: # CI builds are faster with incremental disabled. CARGO_INCREMENTAL: "0" CARGO_BUILD_JOBS: "1" + jobs: - linux-build: + rust-clippy-fmt: runs-on: ubuntu-22.04 - name: ubuntu-22.04 + Java 11 & 17 + name: Rust Clippy and Fmt Check defaults: run: - working-directory: ./java + working-directory: ./java/core/lance-jni steps: - name: Checkout repository uses: actions/checkout@v4 - uses: Swatinem/rust-cache@v2 with: - workspaces: java/java-jni + workspaces: java/core/lance-jni + - name: Install dependencies + run: | + sudo apt update + sudo apt install -y protobuf-compiler libssl-dev - name: Run cargo fmt run: cargo fmt --check - working-directory: ./java/core/lance-jni + - name: Rust Clippy + run: cargo clippy --all-targets -- -D warnings + + build-and-test-java: + runs-on: ubuntu-22.04 + strategy: + matrix: + java-version: [8, 11, 17] + name: Build and Test with Java ${{ matrix.java-version }} + defaults: + run: + working-directory: ./java + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + with: + workspaces: java/core/lance-jni - name: Install dependencies run: | sudo apt update sudo apt install -y protobuf-compiler libssl-dev - - name: Install Java 17 - uses: actions/setup-java@v4 - with: - distribution: temurin - java-version: 17 - cache: "maven" - - run: echo "JAVA_17=$JAVA_HOME" >> $GITHUB_ENV - - name: Install Java 8 + - name: Set up Java ${{ matrix.java-version }} uses: actions/setup-java@v4 with: distribution: temurin - java-version: 8 + java-version: ${{ 
matrix.java-version }} cache: "maven" - - run: echo "JAVA_8=$JAVA_HOME" >> $GITHUB_ENV - - name: Install Java 11 - uses: actions/setup-java@v4 - with: - distribution: temurin - java-version: 11 - cache: "maven" - - name: Java Style Check - run: mvn checkstyle:check - - name: Rust Clippy - working-directory: java/core/lance-jni - run: cargo clippy --all-targets -- -D warnings - - name: Running tests with Java 11 - run: mvn clean test - - name: Running tests with Java 8 - run: JAVA_HOME=$JAVA_8 mvn clean test - - name: Running tests with Java 17 + - name: Running tests with Java ${{ matrix.java-version }} run: | - export JAVA_TOOL_OPTIONS="$JAVA_TOOL_OPTIONS \ - -XX:+IgnoreUnrecognizedVMOptions \ - --add-opens=java.base/java.lang=ALL-UNNAMED \ - --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \ - --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \ - --add-opens=java.base/java.io=ALL-UNNAMED \ - --add-opens=java.base/java.net=ALL-UNNAMED \ - --add-opens=java.base/java.nio=ALL-UNNAMED \ - --add-opens=java.base/java.util=ALL-UNNAMED \ - --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \ - --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \ - --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED \ - --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \ - --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \ - --add-opens=java.base/sun.security.action=ALL-UNNAMED \ - --add-opens=java.base/sun.util.calendar=ALL-UNNAMED \ - --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED \ - -Djdk.reflect.useDirectMethodHandle=false \ - -Dio.netty.tryReflectionSetAccessible=true" - JAVA_HOME=$JAVA_17 mvn clean test + if [ "${{ matrix.java-version }}" == "17" ]; then + export JAVA_TOOL_OPTIONS="$JAVA_TOOL_OPTIONS \ + -XX:+IgnoreUnrecognizedVMOptions \ + --add-opens=java.base/java.lang=ALL-UNNAMED \ + --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \ + --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \ + --add-opens=java.base/java.io=ALL-UNNAMED \ + 
--add-opens=java.base/java.net=ALL-UNNAMED \ + --add-opens=java.base/java.nio=ALL-UNNAMED \ + --add-opens=java.base/java.util=ALL-UNNAMED \ + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \ + --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \ + --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED \ + --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \ + --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \ + --add-opens=java.base/sun.security.action=ALL-UNNAMED \ + --add-opens=java.base/sun.util.calendar=ALL-UNNAMED \ + --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED \ + -Djdk.reflect.useDirectMethodHandle=false \ + -Dio.netty.tryReflectionSetAccessible=true" + fi + mvn clean test diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 743170c619..df6e712cfa 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -110,7 +110,8 @@ jobs: run: sudo rm -rf target/wheels linux-arm: timeout-minutes: 45 - runs-on: warp-ubuntu-latest-arm64-4x + #runs-on: warp-ubuntu-latest-arm64-4x + runs-on: buildjet-4vcpu-ubuntu-2204-arm name: Python Linux 3.${{ matrix.python-minor-version }} ARM strategy: matrix: diff --git a/Cargo.toml b/Cargo.toml index 4274bec8d9..b739f6e94c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ exclude = ["python"] resolver = "2" [workspace.package] -version = "0.12.4" +version = "0.13.1" edition = "2021" authors = ["Lance Devs "] license = "Apache-2.0" @@ -40,23 +40,23 @@ categories = [ "development-tools", "science", ] -rust-version = "1.75" +rust-version = "1.78" [workspace.dependencies] -lance = { version = "=0.12.4", path = "./rust/lance" } -lance-arrow = { version = "=0.12.4", path = "./rust/lance-arrow" } -lance-core = { version = "=0.12.4", path = "./rust/lance-core" } -lance-datafusion = { version = "=0.12.4", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=0.12.4", path = "./rust/lance-datagen" } -lance-encoding = { version = "=0.12.4", path = 
"./rust/lance-encoding" } -lance-encoding-datafusion = { version = "=0.12.4", path = "./rust/lance-encoding-datafusion" } -lance-file = { version = "=0.12.4", path = "./rust/lance-file" } -lance-index = { version = "=0.12.4", path = "./rust/lance-index" } -lance-io = { version = "=0.12.4", path = "./rust/lance-io" } -lance-linalg = { version = "=0.12.4", path = "./rust/lance-linalg" } -lance-table = { version = "=0.12.4", path = "./rust/lance-table" } -lance-test-macros = { version = "=0.12.4", path = "./rust/lance-test-macros" } -lance-testing = { version = "=0.12.4", path = "./rust/lance-testing" } +lance = { version = "=0.13.1", path = "./rust/lance" } +lance-arrow = { version = "=0.13.1", path = "./rust/lance-arrow" } +lance-core = { version = "=0.13.1", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.13.1", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.13.1", path = "./rust/lance-datagen" } +lance-encoding = { version = "=0.13.1", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.13.1", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.13.1", path = "./rust/lance-file" } +lance-index = { version = "=0.13.1", path = "./rust/lance-index" } +lance-io = { version = "=0.13.1", path = "./rust/lance-io" } +lance-linalg = { version = "=0.13.1", path = "./rust/lance-linalg" } +lance-table = { version = "=0.13.1", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.13.1", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.13.1", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "51.0.0", optional = false, features = ["prettyprint"] } diff --git a/docs/integrations/pytorch.rst b/docs/integrations/pytorch.rst index 0b49e2ac14..58d7285203 100644 --- a/docs/integrations/pytorch.rst +++ b/docs/integrations/pytorch.rst @@ -74,3 +74,8 @@ Available samplers: - 
:class:`lance.sampler.ShardedFragmentSampler` - :class:`lance.sampler.ShardedBatchSampler` + +.. warning:: + For multiprocessing you should probably not use fork as lance is + multi-threaded internally and fork and multi-threading do not work well together. + Refer to `this discussion `_. diff --git a/docs/integrations/tensorflow.rst b/docs/integrations/tensorflow.rst index 37ed235316..c381abd363 100644 --- a/docs/integrations/tensorflow.rst +++ b/docs/integrations/tensorflow.rst @@ -88,4 +88,7 @@ workers. for batch in ds: print(batch) - +.. warning:: + For multiprocessing you should probably not use fork as lance is + multi-threaded internally and fork and multi-threading do not work well together. + Refer to `this discussion `_. diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs index 24d9863b22..2265822c11 100644 --- a/java/core/lance-jni/src/fragment.rs +++ b/java/core/lance-jni/src/fragment.rs @@ -34,10 +34,6 @@ use crate::{ RT, }; -/////////////////// -// Write Methods // -/////////////////// - ////////////////// // Read Methods // ////////////////// @@ -70,6 +66,9 @@ fn inner_count_rows_native( Ok(res) } +/////////////////// +// Write Methods // +/////////////////// #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local>( mut env: JNIEnv<'local>, diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index 2e7df86812..7228684264 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -213,7 +213,7 @@ public long latestVersion() { /** * Count the number of rows in the dataset. * - * @return num of rows. 
+ * @return num of rows */ public int countRows() { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { diff --git a/protos/encodings.proto b/protos/encodings.proto index 36ce5a7234..ffd0216afe 100644 --- a/protos/encodings.proto +++ b/protos/encodings.proto @@ -183,6 +183,7 @@ message SimpleStruct {} message Binary { ArrayEncoding indices = 1; ArrayEncoding bytes = 2; + uint64 null_adjustment = 3; } // Encodings that decode into an Arrow array diff --git a/protos/file2.proto b/protos/file2.proto index 63fc379091..31c9b50d49 100644 --- a/protos/file2.proto +++ b/protos/file2.proto @@ -15,10 +15,9 @@ import "google/protobuf/empty.proto"; // // * Each Lance file contains between 0 and 4Gi columns // * Each column contains between 0 and 4Gi pages -// * Each page contains between 0 and 4Gi items +// * Each page contains between 0 and 2^64 items // * Different pages within a column can have different items counts -// * Columns may have more than 4Gi items, though this will require more than -// one page +// * Columns may have up to 2^64 items // * Different columns within a file can have different item counts // // The Lance file format does not have any notion of a type system or schemas. @@ -178,7 +177,7 @@ message ColumnMetadata { // may be empty. repeated uint64 buffer_sizes = 2; // Logical length (e.g. 
# rows) of the page - uint32 length = 3; + uint64 length = 3; // The encoding used to encode the page Encoding encoding = 4; } diff --git a/protos/table.proto b/protos/table.proto index 8daca05e1f..ff500cdddb 100644 --- a/protos/table.proto +++ b/protos/table.proto @@ -235,7 +235,7 @@ message DataFile { // - dimension: packed-struct (0): // - x: u32 (1) // - y: u32 (2) - // - path: string (3) + // - path: list (3) // - embedding: fsl<768> (4) // - fp64 // - borders: fsl<4> (5) @@ -249,7 +249,7 @@ message DataFile { // This reflects quite a few phenomenon: // - The packed struct is encoded into a single column and there is no top-level column // for the x or y fields - // - The string is encoded into two columns + // - The variable sized list is encoded into two columns // - The embedding is encoded into a single column (common for FSL of primitive) and there // is not "FSL column" // - The borders field actually does have an "FSL column" diff --git a/python/Cargo.toml b/python/Cargo.toml index f22c570c2f..fddf1627fe 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.12.4" +version = "0.13.1" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" diff --git a/python/python/lance/util.py b/python/python/lance/util.py index b3ce35ded6..208d8b41ff 100644 --- a/python/python/lance/util.py +++ b/python/python/lance/util.py @@ -11,7 +11,7 @@ from .dependencies import _check_for_numpy, _check_for_pandas from .dependencies import numpy as np from .dependencies import pandas as pd -from .lance import _build_sq_storage, _Hnsw, _KMeans +from .lance import _Hnsw, _KMeans if TYPE_CHECKING: ts_types = Union[datetime, pd.Timestamp, str] @@ -245,9 +245,3 @@ def to_lance_file(self, file_path): def vectors(self) -> pa.Array: return self._hnsw.vectors() - - -def build_sq_storage( - row_ids_array: Iterator[pa.Array], vectors_array: pa.Array, dim, bounds: tuple -) -> pa.RecordBatch: - return 
_build_sq_storage(row_ids_array, vectors_array, dim, bounds) diff --git a/python/src/lib.rs b/python/src/lib.rs index e182582fdb..b90740cf5f 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -68,7 +68,6 @@ pub(crate) mod utils; pub use crate::arrow::{bfloat16_array, BFloat16}; use crate::fragment::{cleanup_partial_writes, write_fragments}; pub use crate::tracing::{trace_to_chrome, TraceGuard}; -use crate::utils::build_sq_storage; use crate::utils::Hnsw; use crate::utils::KMeans; pub use dataset::write_dataset; @@ -142,7 +141,6 @@ fn lance(py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(cleanup_partial_writes))?; m.add_wrapped(wrap_pyfunction!(trace_to_chrome))?; m.add_wrapped(wrap_pyfunction!(manifest_needs_migration))?; - m.add_wrapped(wrap_pyfunction!(build_sq_storage))?; // Debug functions m.add_wrapped(wrap_pyfunction!(debug::format_schema))?; m.add_wrapped(wrap_pyfunction!(debug::format_manifest))?; diff --git a/python/src/utils.rs b/python/src/utils.rs index 9d6c0636ba..101785de24 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -14,23 +14,18 @@ use std::sync::Arc; -use arrow::compute::{concat, concat_batches}; +use arrow::compute::concat; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; -use arrow_array::{ - cast::AsArray, Array, FixedSizeListArray, Float32Array, UInt32Array, UInt64Array, -}; +use arrow_array::{cast::AsArray, Array, FixedSizeListArray, Float32Array, UInt32Array}; use arrow_data::ArrayData; use arrow_schema::DataType; use lance::Result; -use lance::{datatypes::Schema, index::vector::sq, io::ObjectStore}; +use lance::{datatypes::Schema, io::ObjectStore}; use lance_arrow::FixedSizeListArrayExt; use lance_file::writer::FileWriter; use lance_index::scalar::IndexWriter; +use lance_index::vector::hnsw::{builder::HnswBuildParams, HNSW}; use lance_index::vector::v3::subindex::IvfSubIndex; -use lance_index::vector::{ - hnsw::{builder::HnswBuildParams, HNSW}, - storage::VectorStore, -}; use 
lance_linalg::kmeans::compute_partitions; use lance_linalg::{ distance::DistanceType, @@ -41,7 +36,7 @@ use object_store::path::Path; use pyo3::{ exceptions::{PyIOError, PyRuntimeError, PyValueError}, prelude::*, - types::{PyIterator, PyTuple}, + types::PyIterator, }; use crate::RT; @@ -153,12 +148,14 @@ impl Hnsw { max_level=7, m=20, ef_construction=100, + distance_type="l2", ))] fn build( vectors_array: &PyIterator, max_level: u16, m: usize, ef_construction: usize, + distance_type: &str, ) -> PyResult { let params = HnswBuildParams::default() .max_level(max_level) @@ -177,9 +174,12 @@ impl Hnsw { let vectors = concat(&array_refs).map_err(|e| PyIOError::new_err(e.to_string()))?; std::mem::drop(data); + let dt = DistanceType::try_from(distance_type) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + let hnsw = RT .runtime - .block_on(params.build(vectors.clone())) + .block_on(params.build(vectors.clone(), dt)) .map_err(|e| PyIOError::new_err(e.to_string()))?; Ok(Self { hnsw, vectors }) } @@ -215,40 +215,3 @@ impl Hnsw { self.vectors.to_data().to_pyarrow(py) } } - -#[pyfunction(name = "_build_sq_storage")] -pub fn build_sq_storage( - py: Python, - row_ids_array: &PyIterator, - vectors: &PyAny, - dim: usize, - bounds: &PyTuple, -) -> PyResult { - let mut row_ids_arr: Vec> = Vec::new(); - for row_ids in row_ids_array { - let row_ids = ArrayData::from_pyarrow(row_ids?)?; - if !matches!(row_ids.data_type(), DataType::UInt64) { - return Err(PyValueError::new_err("Must be a UInt64")); - } - row_ids_arr.push(Arc::new(UInt64Array::from(row_ids))); - } - let row_ids_refs = row_ids_arr.iter().map(|a| a.as_ref()).collect::>(); - let row_ids = concat(&row_ids_refs).map_err(|e| PyIOError::new_err(e.to_string()))?; - std::mem::drop(row_ids_arr); - - let vectors = Arc::new(FixedSizeListArray::from(ArrayData::from_pyarrow(vectors)?)); - - let lower_bound = bounds.get_item(0)?.extract::()?; - let upper_bound = bounds.get_item(1)?.extract::()?; - let quantizer = - 
lance_index::vector::sq::ScalarQuantizer::with_bounds(8, dim, lower_bound..upper_bound); - let storage = sq::build_sq_storage(DistanceType::L2, row_ids, vectors, quantizer) - .map_err(|e| PyIOError::new_err(e.to_string()))?; - let batches = storage - .to_batches() - .map_err(|e| PyIOError::new_err(e.to_string()))? - .collect::>(); - let batch = concat_batches(&batches[0].schema(), &batches) - .map_err(|e| PyIOError::new_err(e.to_string()))?; - batch.to_pyarrow(py) -} diff --git a/rust/lance-arrow/src/lib.rs b/rust/lance-arrow/src/lib.rs index e1cf554d59..059221c268 100644 --- a/rust/lance-arrow/src/lib.rs +++ b/rust/lance-arrow/src/lib.rs @@ -5,6 +5,7 @@ //! //! To improve Arrow-RS ergonomic +use std::collections::HashMap; use std::sync::Arc; use arrow_array::{ @@ -374,6 +375,19 @@ pub trait RecordBatchExt { /// Project the schema over the [RecordBatch]. fn project_by_schema(&self, schema: &Schema) -> Result; + /// metadata of the schema. + fn metadata(&self) -> &HashMap; + + /// Add metadata to the schema. + fn add_metadata(&self, key: String, value: String) -> Result { + let mut metadata = self.metadata().clone(); + metadata.insert(key, value); + self.with_metadata(metadata) + } + + /// Replace the schema metadata with the provided one. + fn with_metadata(&self, metadata: HashMap) -> Result; + /// Take selected rows from the [RecordBatch]. fn take(&self, indices: &UInt32Array) -> Result; } @@ -460,6 +474,16 @@ impl RecordBatchExt for RecordBatch { self.try_new_from_struct_array(project(&struct_array, schema.fields())?) 
} + fn metadata(&self) -> &HashMap { + self.schema_ref().metadata() + } + + fn with_metadata(&self, metadata: HashMap) -> Result { + let mut schema = self.schema_ref().as_ref().clone(); + schema.metadata = metadata; + Self::try_new(schema.into(), self.columns().into()) + } + fn take(&self, indices: &UInt32Array) -> Result { let struct_array: StructArray = self.clone().into(); let taken = take(&struct_array, indices, None)?; diff --git a/rust/lance-core/src/utils/cpu.rs b/rust/lance-core/src/utils/cpu.rs index b2e0ba31bc..4427922dd1 100644 --- a/rust/lance-core/src/utils/cpu.rs +++ b/rust/lance-core/src/utils/cpu.rs @@ -92,6 +92,14 @@ mod aarch64 { } } +#[cfg(all(target_arch = "aarch64", target_os = "windows"))] +mod aarch64 { + pub fn has_neon_f16_support() -> bool { + // https://github.com/lancedb/lance/issues/2411 + false + } +} + #[cfg(target_arch = "loongarch64")] mod loongarch64 { pub fn has_lsx_support() -> bool { diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs index c684c89635..e110680e1f 100644 --- a/rust/lance-encoding/src/decoder.rs +++ b/rust/lance-encoding/src/decoder.rs @@ -222,6 +222,7 @@ use bytes::{Bytes, BytesMut}; use futures::future::BoxFuture; use futures::stream::{BoxStream, FuturesOrdered}; use futures::{FutureExt, StreamExt, TryStreamExt}; +use lance_arrow::DataTypeExt; use lance_core::datatypes::{Field, Schema}; use log::trace; use snafu::{location, Location}; @@ -230,9 +231,7 @@ use tokio::sync::mpsc::{self, unbounded_channel}; use lance_core::{Error, Result}; use tracing::instrument; -use crate::encoder::get_str_encoding_type; use crate::encoder::{values_column_encoding, EncodedBatch}; -use crate::encodings::logical::binary::BinaryFieldScheduler; use crate::encodings::logical::list::{ListFieldScheduler, OffsetPageInfo}; use crate::encodings::logical::primitive::PrimitiveFieldScheduler; use crate::encodings::logical::r#struct::{SimpleStructDecoder, SimpleStructScheduler}; @@ -246,7 +245,7 @@ use 
crate::{BufferScheduler, EncodingsIo}; #[derive(Debug)] pub struct PageInfo { /// The number of rows in the page - pub num_rows: u32, + pub num_rows: u64, /// The encoding that explains the buffers in the page pub encoding: pb::ArrayEncoding, /// The offsets and sizes of the buffers in the file @@ -550,26 +549,12 @@ impl CoreFieldDecoderStrategy { } fn is_primitive(data_type: &DataType) -> bool { - if data_type.is_primitive() { + if data_type.is_primitive() | data_type.is_binary_like() { true - } else if get_str_encoding_type() { - match data_type { - // DataType::is_primitive doesn't consider these primitive but we do - DataType::Boolean - | DataType::Null - | DataType::FixedSizeBinary(_) - | DataType::Utf8 => true, - DataType::FixedSizeList(inner, _) => Self::is_primitive(inner.data_type()), - _ => false, - } } else { match data_type { // DataType::is_primitive doesn't consider these primitive but we do - DataType::Boolean - | DataType::Null - | DataType::FixedSizeBinary(_) - // | DataType::Utf8 - => true, + DataType::Boolean | DataType::Null | DataType::FixedSizeBinary(_) => true, DataType::FixedSizeList(inner, _) => Self::is_primitive(inner.data_type()), _ => false, } @@ -598,7 +583,7 @@ impl CoreFieldDecoderStrategy { /// Helper method to verify the page encoding of a struct header column fn check_simple_struct(column_info: &ColumnInfo, path: &VecDeque) -> Result<()> { Self::ensure_values_encoded(column_info, path)?; - if !column_info.page_infos.len() == 1 { + if column_info.page_infos.len() != 1 { return Err(Error::InvalidInput { source: format!("Due to schema we expected a struct column but we received a column with {} pages and right now we only support struct columns with 1 page", column_info.page_infos.len()).into(), location: location!() }); } let encoding = &column_info.page_infos[0].encoding; @@ -711,28 +696,6 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy { .boxed(); Ok((chain, list_scheduler_fut)) } - DataType::Utf8 | DataType::Binary | 
DataType::LargeBinary | DataType::LargeUtf8 => { - let list_type = if matches!(data_type, DataType::Utf8 | DataType::Binary) { - DataType::List(Arc::new(ArrowField::new("item", DataType::UInt8, true))) - } else { - DataType::LargeList(Arc::new(ArrowField::new("item", DataType::UInt8, true))) - }; - let list_field = ArrowField::new(&field.name, list_type, true); - let list_field = Field::try_from(&list_field).unwrap(); - // We've changed the data type but are still decoding the same "field" - let (chain, list_decoder) = - chain.restart_at_current(&list_field, column_infos, buffers)?; - let data_type = data_type.clone(); - let binary_scheduler_fut = async move { - let list_decoder = list_decoder.await?; - Ok( - Arc::new(BinaryFieldScheduler::new(list_decoder, data_type.clone())) - as Arc, - ) - } - .boxed(); - Ok((chain, binary_scheduler_fut)) - } DataType::Struct(fields) => { let column_info = column_infos.pop_front().unwrap(); Self::check_simple_struct(&column_info, chain.current_path()).unwrap(); @@ -775,9 +738,9 @@ fn root_column(num_rows: u64) -> ColumnInfo { let root_pages = (0..num_root_pages) .map(|i| PageInfo { num_rows: if i == num_root_pages - 1 { - final_page_num_rows as u32 + final_page_num_rows } else { - u32::MAX + u64::MAX }, encoding: pb::ArrayEncoding { array_encoding: Some(pb::array_encoding::ArrayEncoding::Struct( @@ -861,8 +824,8 @@ impl DecodeBatchScheduler { return; } let next_scan_line = maybe_next_scan_line.unwrap(); - num_rows_scheduled += next_scan_line.rows_scheduled as u64; - rows_to_schedule -= next_scan_line.rows_scheduled as u64; + num_rows_scheduled += next_scan_line.rows_scheduled; + rows_to_schedule -= next_scan_line.rows_scheduled; trace!( "Scheduled scan line of {} rows and {} decoders", next_scan_line.rows_scheduled, @@ -1060,11 +1023,10 @@ impl BatchDecodeStream { return Ok(None); } - let mut to_take = self.rows_remaining.min(self.rows_per_batch as u64) as u32; - self.rows_remaining -= to_take as u64; + let mut to_take = 
self.rows_remaining.min(self.rows_per_batch as u64); + self.rows_remaining -= to_take; - let scheduled_need = - (self.rows_drained + to_take as u64).saturating_sub(self.rows_scheduled); + let scheduled_need = (self.rows_drained + to_take).saturating_sub(self.rows_scheduled); trace!("scheduled_need = {} because rows_drained = {} and to_take = {} and rows_scheduled = {}", scheduled_need, self.rows_drained, to_take, self.rows_scheduled); if scheduled_need > 0 { let desired_scheduled = scheduled_need + self.rows_scheduled; @@ -1075,7 +1037,7 @@ impl BatchDecodeStream { let actually_scheduled = self.wait_for_scheduled(desired_scheduled).await?; if actually_scheduled < desired_scheduled { let under_scheduled = desired_scheduled - actually_scheduled; - to_take -= under_scheduled as u32; + to_take -= under_scheduled; } } @@ -1083,17 +1045,17 @@ impl BatchDecodeStream { return Ok(None); } - let avail = self.root_decoder.avail_u64(); + let avail = self.root_decoder.avail(); trace!("Top level page has {} rows already available", avail); - if avail < to_take as u64 { + if avail < to_take { trace!( "Top level page waiting for an additional {} rows", - to_take as u64 - avail + to_take - avail ); self.root_decoder.wait(to_take).await?; } let next_task = self.root_decoder.drain(to_take)?; - self.rows_drained += to_take as u64; + self.rows_drained += to_take; Ok(Some(next_task)) } @@ -1125,7 +1087,12 @@ impl BatchDecodeStream { }); next_task.map(|(task, num_rows)| { let task = task.map(|join_wrapper| join_wrapper.unwrap()).boxed(); - let next_task = ReadBatchTask { task, num_rows }; + // This should be true since batch size is u32 + debug_assert!(num_rows <= u32::MAX as u64); + let next_task = ReadBatchTask { + task, + num_rows: num_rows as u32, + }; (next_task, slf) }) }); @@ -1144,50 +1111,44 @@ impl BatchDecodeStream { /// the decode task for batch 0 and the decode task for batch 1. 
/// /// See [`crate::decoder`] for more information -pub trait PhysicalPageDecoder: Send + Sync { - /// Calculates and updates the capacity required to represent the requested data +pub trait PrimitivePageDecoder: Send + Sync { + /// Decode data into buffers + /// + /// This may be a simple zero-copy from a disk buffer or could involve complex decoding + /// such as decompressing from some compressed representation. /// /// Capacity is stored as a tuple of (num_bytes: u64, is_needed: bool). The `is_needed` /// portion only needs to be updated if the encoding has some concept of an "optional" /// buffer. /// - /// The decoder should look at `rows_to_skip` and `num_rows` and then calculate how - /// many bytes of data are needed. It should then update the first part of the tuple. - /// - /// Note: Most encodings deal with a single buffer. They may have multiple input buffers - /// but they only have a single output buffer. The current exception to this rule is the - /// `basic` encoding which has an output "validity" buffer and an output "values" buffers. - /// We may find there are other such exceptions. + /// Encodings can have any number of input or output buffers. For example, a dictionary + /// decoding will convert two buffers (indices + dictionary) into a single buffer /// - /// # Arguments + /// Binary decodings have two output buffers (one for values, one for offsets) /// - /// * `rows_to_skip` - how many rows to skip (within the page) before decoding - /// * `num_rows` - how many rows to decode - /// * `buffers` - A mutable slice of "capacities" (as described above), one per buffer - /// * `all_null` - A mutable bool, set to true if a decoder determines all values are null - fn update_capacity( - &self, - rows_to_skip: u32, - num_rows: u32, - buffers: &mut [(u64, bool)], - all_null: &mut bool, - ); - /// Decodes the data into the requested buffers. + /// Other decodings could even expand the # of output buffers. 
For example, we could decode + /// fixed size strings into variable length strings going from one input buffer to multiple output + /// buffers. /// - /// You can assume that the capacity will have already been configured on the `BytesMut` - /// according to the capacity calculated in [`PhysicalPageDecoder::update_capacity`] + /// Each Arrow data type typically has a fixed structure of buffers and the encoding chain will + /// generally end at one of these structures. However, intermediate structures may exist which + /// do not correspond to any Arrow type at all. For example, a bitpacking encoding will deal + /// with buffers that have bits-per-value that is not a multiple of 8. /// + /// The `primitive_array_from_buffers` method has an expected buffer layout for each arrow + /// type (order matters) and encodings that aim to decode into arrow types should respect + /// this layout. /// # Arguments /// /// * `rows_to_skip` - how many rows to skip (within the page) before decoding /// * `num_rows` - how many rows to decode - /// * `dest_buffers` - the output buffers to decode into - fn decode_into( + /// * `all_null` - A mutable bool, set to true if a decoder determines all values are null + fn decode( &self, - rows_to_skip: u32, - num_rows: u32, - dest_buffers: &mut [BytesMut], - ) -> Result<()>; + rows_to_skip: u64, + num_rows: u64, + all_null: &mut bool, + ) -> Result>; fn num_buffers(&self) -> u32; } @@ -1199,7 +1160,6 @@ pub trait PhysicalPageDecoder: Send + Sync { /// be shared in follow-up I/O tasks. /// /// See [`crate::decoder`] for more information - pub trait PageScheduler: Send + Sync + std::fmt::Debug { /// Schedules a batch of I/O to load the data needed for the requested ranges /// @@ -1214,10 +1174,10 @@ pub trait PageScheduler: Send + Sync + std::fmt::Debug { /// scheduled. 
This can be used to assign priority to I/O requests fn schedule_ranges( &self, - ranges: &[Range], + ranges: &[Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>>; + ) -> BoxFuture<'static, Result>>; } /// Contains the context for a scheduler @@ -1290,7 +1250,7 @@ impl SchedulerContext { #[derive(Debug)] pub struct ScheduledScanLine { - pub rows_scheduled: u32, + pub rows_scheduled: u64, pub decoders: Vec, } @@ -1372,7 +1332,7 @@ pub struct NextDecodeTask { /// The decode task itself pub task: Box, /// The number of rows that will be created - pub num_rows: u32, + pub num_rows: u64, /// Whether or not the decoder that created this still has more rows to decode pub has_more: bool, } @@ -1440,13 +1400,13 @@ pub trait LogicalPageDecoder: std::fmt::Debug + Send { }) } /// Waits for enough data to be loaded to decode `num_rows` of data - fn wait(&mut self, num_rows: u32) -> BoxFuture>; + fn wait(&mut self, num_rows: u64) -> BoxFuture>; /// Creates a task to decode `num_rows` of data into an array - fn drain(&mut self, num_rows: u32) -> Result; + fn drain(&mut self, num_rows: u64) -> Result; /// The number of rows that are in the page but haven't yet been "waited" - fn unawaited(&self) -> u32; + fn unawaited(&self) -> u64; /// The number of rows that have been "waited" but not yet decoded - fn avail(&self) -> u32; + fn avail(&self) -> u64; /// The data type of the decoded data fn data_type(&self) -> &DataType; } diff --git a/rust/lance-encoding/src/encoder.rs b/rust/lance-encoding/src/encoder.rs index 33c82cede3..ca368f7a9a 100644 --- a/rust/lance-encoding/src/encoder.rs +++ b/rust/lance-encoding/src/encoder.rs @@ -15,8 +15,7 @@ use crate::{ decoder::{ColumnInfo, PageInfo}, encodings::{ logical::{ - binary::BinaryFieldEncoder, list::ListFieldEncoder, primitive::PrimitiveFieldEncoder, - r#struct::StructFieldEncoder, + list::ListFieldEncoder, primitive::PrimitiveFieldEncoder, r#struct::StructFieldEncoder, }, physical::{ basic::BasicEncoder, 
binary::BinaryEncoder, fixed_size_list::FslEncoder, @@ -102,7 +101,7 @@ pub struct EncodedPage { // The encoded array data pub array: EncodedArray, /// The number of rows in the encoded page - pub num_rows: u32, + pub num_rows: u64, /// The index of the column pub column_idx: u32, } @@ -228,11 +227,6 @@ fn get_compression_scheme() -> CompressionScheme { parse_compression_scheme(&compression_scheme).unwrap_or(CompressionScheme::None) } -pub fn get_str_encoding_type() -> bool { - let str_encoding = std::env::var("LANCE_STR_ARRAY_ENCODING").unwrap_or("none".to_string()); - matches!(str_encoding.as_str(), "binary") -} - impl CoreArrayEncodingStrategy { fn array_encoder_from_type(data_type: &DataType) -> Result> { match data_type { @@ -242,20 +236,14 @@ impl CoreArrayEncodingStrategy { *dimension as u32, ))))) } - DataType::Utf8 => { - if get_str_encoding_type() { - let bin_indices_encoder = Self::array_encoder_from_type(&DataType::UInt64)?; - let bin_bytes_encoder = Self::array_encoder_from_type(&DataType::UInt8)?; - - Ok(Box::new(BinaryEncoder::new( - bin_indices_encoder, - bin_bytes_encoder, - ))) - } else { - Ok(Box::new(BasicEncoder::new(Box::new( - ValueEncoder::try_new(data_type, get_compression_scheme())?, - )))) - } + DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => { + let bin_indices_encoder = Self::array_encoder_from_type(&DataType::UInt64)?; + let bin_bytes_encoder = Self::array_encoder_from_type(&DataType::UInt8)?; + + Ok(Box::new(BinaryEncoder::new( + bin_indices_encoder, + bin_bytes_encoder, + ))) } _ => Ok(Box::new(BasicEncoder::new(Box::new( ValueEncoder::try_new(data_type, get_compression_scheme())?, @@ -369,30 +357,16 @@ impl FieldEncodingStrategy for CoreFieldEncodingStrategy { | DataType::UInt64 | DataType::UInt8 | DataType::FixedSizeBinary(_) - | DataType::FixedSizeList(_, _) => Ok(Box::new(PrimitiveFieldEncoder::try_new( + | DataType::FixedSizeList(_, _) + | DataType::Binary + | DataType::LargeBinary + | 
DataType::Utf8 + | DataType::LargeUtf8 => Ok(Box::new(PrimitiveFieldEncoder::try_new( cache_bytes_per_column, keep_original_array, self.array_encoding_strategy.clone(), column_index.next_column_index(field.id), )?)), - DataType::Utf8 => { - if get_str_encoding_type() { - Ok(Box::new(PrimitiveFieldEncoder::try_new( - cache_bytes_per_column, - keep_original_array, - self.array_encoding_strategy.clone(), - column_index.next_column_index(field.id), - )?)) - } else { - let list_idx = column_index.next_column_index(field.id); - column_index.skip(); - Ok(Box::new(BinaryFieldEncoder::new( - cache_bytes_per_column, - keep_original_array, - list_idx, - ))) - } - } DataType::List(child) => { let list_idx = column_index.next_column_index(field.id); let inner_encoding = encoding_strategy_root.create_field_encoder( @@ -431,15 +405,6 @@ impl FieldEncodingStrategy for CoreFieldEncodingStrategy { header_idx, ))) } - DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => { - let list_idx = column_index.next_column_index(field.id); - column_index.skip(); - Ok(Box::new(BinaryFieldEncoder::new( - cache_bytes_per_column, - keep_original_array, - list_idx, - ))) - } _ => todo!("Implement encoding for field {}", field), } } diff --git a/rust/lance-encoding/src/encodings/logical.rs b/rust/lance-encoding/src/encodings/logical.rs index adca795788..cad0535571 100644 --- a/rust/lance-encoding/src/encodings/logical.rs +++ b/rust/lance-encoding/src/encodings/logical.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -pub mod binary; pub mod list; pub mod primitive; pub mod r#struct; diff --git a/rust/lance-encoding/src/encodings/logical/binary.rs b/rust/lance-encoding/src/encodings/logical/binary.rs deleted file mode 100644 index 79c3847326..0000000000 --- a/rust/lance-encoding/src/encodings/logical/binary.rs +++ /dev/null @@ -1,296 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The 
Lance Authors - -use std::sync::Arc; - -use arrow_array::{ - cast::AsArray, - types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, UInt8Type, Utf8Type}, - Array, ArrayRef, GenericByteArray, GenericListArray, LargeListArray, ListArray, UInt8Array, -}; - -use arrow_buffer::ScalarBuffer; -use arrow_schema::{DataType, Field}; -use futures::future::BoxFuture; -use lance_core::Result; -use log::trace; - -use crate::{ - decoder::{ - DecodeArrayTask, DecoderReady, FieldScheduler, FilterExpression, LogicalPageDecoder, - NextDecodeTask, ScheduledScanLine, SchedulerContext, SchedulingJob, - }, - encoder::{CoreArrayEncodingStrategy, EncodeTask, FieldEncoder}, -}; - -use super::{list::ListFieldEncoder, primitive::PrimitiveFieldEncoder}; - -/// Wraps a varbin scheduler and uses a BinaryPageDecoder to cast -/// the result to the appropriate type -#[derive(Debug)] -pub struct BinarySchedulingJob<'a> { - scheduler: &'a BinaryFieldScheduler, - inner: Box, -} - -impl<'a> SchedulingJob for BinarySchedulingJob<'a> { - fn schedule_next( - &mut self, - context: &mut SchedulerContext, - top_level_row: u64, - ) -> Result { - let inner_scan = self.inner.schedule_next(context, top_level_row)?; - let wrapped_decoders = inner_scan - .decoders - .into_iter() - .map(|decoder| DecoderReady { - path: decoder.path, - decoder: Box::new(BinaryPageDecoder { - inner: decoder.decoder, - data_type: self.scheduler.data_type.clone(), - }), - }) - .collect::>(); - Ok(ScheduledScanLine { - decoders: wrapped_decoders, - rows_scheduled: inner_scan.rows_scheduled, - }) - } - - fn num_rows(&self) -> u64 { - self.inner.num_rows() - } -} - -/// A logical scheduler for utf8/binary pages which assumes the data are encoded as List -#[derive(Debug)] -pub struct BinaryFieldScheduler { - varbin_scheduler: Arc, - data_type: DataType, -} - -impl BinaryFieldScheduler { - // Create a new ListPageScheduler - pub fn new(varbin_scheduler: Arc, data_type: DataType) -> Self { - Self { - varbin_scheduler, - 
data_type, - } - } -} - -impl FieldScheduler for BinaryFieldScheduler { - fn schedule_ranges<'a>( - &'a self, - ranges: &[std::ops::Range], - filter: &FilterExpression, - ) -> Result> { - trace!("Scheduling binary for {} ranges", ranges.len()); - let varbin_job = self.varbin_scheduler.schedule_ranges(ranges, filter)?; - Ok(Box::new(BinarySchedulingJob { - scheduler: self, - inner: varbin_job, - })) - } - - fn num_rows(&self) -> u64 { - self.varbin_scheduler.num_rows() - } -} - -#[derive(Debug)] -pub struct BinaryPageDecoder { - inner: Box, - data_type: DataType, -} - -impl LogicalPageDecoder for BinaryPageDecoder { - fn wait(&mut self, num_rows: u32) -> BoxFuture> { - self.inner.wait(num_rows) - } - - fn drain(&mut self, num_rows: u32) -> Result { - let inner_task = self.inner.drain(num_rows)?; - Ok(NextDecodeTask { - has_more: inner_task.has_more, - num_rows: inner_task.num_rows, - task: Box::new(BinaryArrayDecoder { - inner: inner_task.task, - data_type: self.data_type.clone(), - }), - }) - } - - fn unawaited(&self) -> u32 { - self.inner.unawaited() - } - - fn avail(&self) -> u32 { - self.inner.avail() - } - - fn data_type(&self) -> &DataType { - &self.data_type - } -} - -pub struct BinaryArrayDecoder { - inner: Box, - data_type: DataType, -} - -impl BinaryArrayDecoder { - fn from_list_array(array: &GenericListArray) -> ArrayRef { - let values = array - .values() - .as_primitive::() - .values() - .inner() - .clone(); - let offsets = array.offsets().clone(); - Arc::new(GenericByteArray::::new( - offsets, - values, - array.nulls().cloned(), - )) - } -} - -impl DecodeArrayTask for BinaryArrayDecoder { - fn decode(self: Box) -> Result { - let data_type = self.data_type; - let arr = self.inner.decode()?; - match data_type { - DataType::Binary => Ok(Self::from_list_array::(arr.as_list::())), - DataType::LargeBinary => Ok(Self::from_list_array::( - arr.as_list::(), - )), - DataType::Utf8 => Ok(Self::from_list_array::(arr.as_list::())), - DataType::LargeUtf8 => 
Ok(Self::from_list_array::(arr.as_list::())), - _ => panic!("Binary decoder does not support this data type"), - } - } -} - -/// An encoder which encodes string arrays as List -pub struct BinaryFieldEncoder { - varbin_encoder: Box, -} - -impl BinaryFieldEncoder { - pub fn new(cache_bytes_per_column: u64, keep_original_array: bool, column_index: u32) -> Self { - let items_encoder = Box::new( - PrimitiveFieldEncoder::try_new( - cache_bytes_per_column, - keep_original_array, - Arc::new(CoreArrayEncodingStrategy), - column_index + 1, - ) - .unwrap(), - ); - Self { - varbin_encoder: Box::new(ListFieldEncoder::new( - items_encoder, - cache_bytes_per_column, - keep_original_array, - column_index, - )), - } - } - - fn byte_to_list_array>( - array: &GenericByteArray, - ) -> ListArray { - let values = UInt8Array::new( - ScalarBuffer::::new(array.values().clone(), 0, array.values().len()), - None, - ); - let list_field = Field::new("item", DataType::UInt8, true); - ListArray::new( - Arc::new(list_field), - array.offsets().clone(), - Arc::new(values), - array.nulls().cloned(), - ) - } - - fn byte_to_large_list_array>( - array: &GenericByteArray, - ) -> LargeListArray { - let values = UInt8Array::new( - ScalarBuffer::::new(array.values().clone(), 0, array.values().len()), - None, - ); - let list_field = Field::new("item", DataType::UInt8, true); - LargeListArray::new( - Arc::new(list_field), - array.offsets().clone(), - Arc::new(values), - array.nulls().cloned(), - ) - } - - fn to_list_array(array: ArrayRef) -> ArrayRef { - match array.data_type() { - DataType::Utf8 => Arc::new(Self::byte_to_list_array(array.as_string::())), - DataType::LargeUtf8 => { - Arc::new(Self::byte_to_large_list_array(array.as_string::())) - } - DataType::Binary => Arc::new(Self::byte_to_list_array(array.as_binary::())), - DataType::LargeBinary => { - Arc::new(Self::byte_to_large_list_array(array.as_binary::())) - } - _ => panic!("Binary encoder does not support {}", array.data_type()), - } - } -} - 
-impl FieldEncoder for BinaryFieldEncoder { - fn maybe_encode(&mut self, array: ArrayRef) -> Result> { - let list_array = Self::to_list_array(array); - self.varbin_encoder.maybe_encode(Arc::new(list_array)) - } - - fn flush(&mut self) -> Result> { - self.varbin_encoder.flush() - } - - fn num_columns(&self) -> u32 { - 2 - } - - fn finish(&mut self) -> BoxFuture<'_, Result>> { - self.varbin_encoder.finish() - } -} - -#[cfg(test)] -mod tests { - use arrow_schema::{DataType, Field}; - - use crate::testing::check_round_trip_encoding_random; - - #[test_log::test(tokio::test)] - async fn test_utf8() { - let field = Field::new("", DataType::Utf8, false); - check_round_trip_encoding_random(field).await; - } - - #[test_log::test(tokio::test)] - async fn test_binary() { - let field = Field::new("", DataType::Binary, false); - check_round_trip_encoding_random(field).await; - } - - #[test_log::test(tokio::test)] - async fn test_large_binary() { - let field = Field::new("", DataType::LargeBinary, true); - check_round_trip_encoding_random(field).await; - } - - #[test_log::test(tokio::test)] - async fn test_large_utf8() { - let field = Field::new("", DataType::LargeUtf8, true); - check_round_trip_encoding_random(field).await; - } -} diff --git a/rust/lance-encoding/src/encodings/logical/list.rs b/rust/lance-encoding/src/encodings/logical/list.rs index 48bfcfdd8b..fd7d290ef8 100644 --- a/rust/lance-encoding/src/encodings/logical/list.rs +++ b/rust/lance-encoding/src/encodings/logical/list.rs @@ -97,9 +97,9 @@ impl ListRequestsIter { let mut range = range.clone(); // Skip any offsets pages that are before the range - while offsets_offset + (cur_page_info.offsets_in_page as u64) <= range.start { + while offsets_offset + (cur_page_info.offsets_in_page) <= range.start { trace!("Skipping null offset adjustment chunk {:?}", offsets_offset); - offsets_offset += cur_page_info.offsets_in_page as u64; + offsets_offset += cur_page_info.offsets_in_page; items_offset += 
cur_page_info.num_items_referenced_by_page; cur_page_info = page_infos_iter.next().unwrap(); } @@ -118,7 +118,7 @@ impl ListRequestsIter { while !range.is_empty() { // The end of the list request is the min of the end of the range // and the end of the current page - let end = offsets_offset + cur_page_info.offsets_in_page as u64; + let end = offsets_offset + cur_page_info.offsets_in_page; let last = end >= range.end; let end = end.min(range.end); list_requests.push_back(ListRequest { @@ -133,7 +133,7 @@ impl ListRequestsIter { // If there is still more data in the range, we need to move to the // next page if !last { - offsets_offset += cur_page_info.offsets_in_page as u64; + offsets_offset += cur_page_info.offsets_in_page; items_offset += cur_page_info.num_items_referenced_by_page; cur_page_info = page_infos_iter.next().unwrap(); } @@ -146,7 +146,7 @@ impl ListRequestsIter { } // Given a page of offset data, grab the corresponding list requests - fn next(&mut self, mut num_offsets: u32) -> Vec { + fn next(&mut self, mut num_offsets: u64) -> Vec { let mut list_requests = Vec::new(); while num_offsets > 0 { let req = self.list_requests.front_mut().unwrap(); @@ -155,12 +155,12 @@ impl ListRequestsIter { num_offsets -= 1; debug_assert_ne!(num_offsets, 0); } - if num_offsets as u64 >= req.num_lists { - num_offsets -= req.num_lists as u32; + if num_offsets >= req.num_lists { + num_offsets -= req.num_lists; list_requests.push(self.list_requests.pop_front().unwrap()); } else { let sub_req = ListRequest { - num_lists: num_offsets as u64, + num_lists: num_offsets, includes_extra_offset: req.includes_extra_offset, null_offset_adjustment: req.null_offset_adjustment, items_offset: req.items_offset, @@ -168,7 +168,7 @@ impl ListRequestsIter { list_requests.push(sub_req); req.includes_extra_offset = false; - req.num_lists -= num_offsets as u64; + req.num_lists -= num_offsets; num_offsets = 0; } } @@ -430,7 +430,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> { 
debug_assert!(list_reqs .iter() .all(|req| req.null_offset_adjustment == null_offset_adjustment)); - let num_rows = list_reqs.iter().map(|req| req.num_lists).sum::() as u32; + let num_rows = list_reqs.iter().map(|req| req.num_lists).sum::(); // offsets is a uint64 which is guaranteed to create one decoder on each call to schedule_next let next_offsets_decoder = next_offsets.decoders.into_iter().next().unwrap().decoder; @@ -502,7 +502,7 @@ pub struct ListFieldScheduler { /// This is needed to construct the scheduler #[derive(Debug)] pub struct OffsetPageInfo { - pub offsets_in_page: u32, + pub offsets_in_page: u64, pub null_offset_adjustment: u64, pub num_items_referenced_by_page: u64, } @@ -570,9 +570,9 @@ struct ListPageDecoder { offsets: Vec, validity: BooleanBuffer, item_decoder: Option, - lists_available: u32, - num_rows: u32, - rows_drained: u32, + lists_available: u64, + num_rows: u64, + rows_drained: u64, items_type: DataType, offset_type: DataType, data_type: DataType, @@ -642,7 +642,7 @@ impl DecodeArrayTask for ListDecodeTask { } impl LogicalPageDecoder for ListPageDecoder { - fn wait(&mut self, num_rows: u32) -> BoxFuture> { + fn wait(&mut self, num_rows: u64) -> BoxFuture> { async move { // wait for the indirect I/O to finish, run the scheduler for the indirect // I/O and then wait for enough items to arrive @@ -679,7 +679,7 @@ impl LogicalPageDecoder for ListPageDecoder { self.offsets[offset_wait_start as usize + num_rows as usize] - item_start; if items_needed > 0 { // First discount any already available items - let items_already_available = self.item_decoder.as_mut().unwrap().avail_u64(); + let items_already_available = self.item_decoder.as_mut().unwrap().avail(); trace!( "List's items decoder needs {} items and already has {} items available", items_needed, @@ -687,7 +687,7 @@ impl LogicalPageDecoder for ListPageDecoder { ); items_needed = items_needed.saturating_sub(items_already_available); if items_needed > 0 { - 
self.item_decoder.as_mut().unwrap().wait_u64(items_needed).await?; + self.item_decoder.as_mut().unwrap().wait(items_needed).await?; } } // This is technically undercounting a little. It's possible that we loaded a big items @@ -700,11 +700,11 @@ impl LogicalPageDecoder for ListPageDecoder { .boxed() } - fn unawaited(&self) -> u32 { + fn unawaited(&self) -> u64 { self.num_rows - self.lists_available - self.rows_drained } - fn drain(&mut self, num_rows: u32) -> Result { + fn drain(&mut self, num_rows: u64) -> Result { self.lists_available -= num_rows; // We already have the offsets but need to drain the item pages let mut actual_num_rows = num_rows; @@ -745,7 +745,7 @@ impl LogicalPageDecoder for ListPageDecoder { } else { self.item_decoder .as_mut() - .map(|item_decoder| Result::Ok(item_decoder.drain_u64(num_items_to_drain)?.task)) + .map(|item_decoder| Result::Ok(item_decoder.drain(num_items_to_drain)?.task)) .transpose()? }; @@ -763,7 +763,7 @@ impl LogicalPageDecoder for ListPageDecoder { }) } - fn avail(&self) -> u32 { + fn avail(&self) -> u64 { self.lists_available } @@ -878,7 +878,7 @@ impl ListOffsetsEncoder { tokio::task::spawn(async move { let num_rows = offset_arrays.iter().map(|arr| arr.len()).sum::() - offset_arrays.len(); - let num_rows = num_rows as u32; + let num_rows = num_rows as u64; let mut buffer_index = 0; let array = Self::do_encode( offset_arrays, @@ -1012,7 +1012,7 @@ impl ListOffsetsEncoder { fn do_encode_u64( offset_arrays: Vec, validity: Vec>, - num_offsets: u32, + num_offsets: u64, null_offset_adjustment: u64, buffer_index: &mut u32, inner_encoder: Arc, @@ -1035,7 +1035,7 @@ impl ListOffsetsEncoder { offset_arrays: Vec, validity_arrays: Vec, buffer_index: &mut u32, - num_offsets: u32, + num_offsets: u64, inner_encoder: Arc, ) -> Result { let validity_arrays = validity_arrays diff --git a/rust/lance-encoding/src/encodings/logical/primitive.rs b/rust/lance-encoding/src/encodings/logical/primitive.rs index 003a179943..44dab86ffd 100644 --- 
a/rust/lance-encoding/src/encodings/logical/primitive.rs +++ b/rust/lance-encoding/src/encodings/logical/primitive.rs @@ -6,15 +6,17 @@ use std::{fmt::Debug, ops::Range, sync::Arc}; use arrow_array::{ new_null_array, types::{ - ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + ArrowPrimitiveType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + DurationSecondType, Float16Type, Float32Type, Float64Type, GenericBinaryType, + GenericStringType, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, }, - ArrayRef, BooleanArray, FixedSizeBinaryArray, FixedSizeListArray, PrimitiveArray, StringArray, + ArrayRef, BooleanArray, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, + PrimitiveArray, }; use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_schema::{DataType, IntervalUnit, TimeUnit}; @@ -29,7 +31,7 @@ use lance_core::{Error, Result}; use crate::{ decoder::{ DecodeArrayTask, FieldScheduler, FilterExpression, LogicalPageDecoder, NextDecodeTask, - PageInfo, PageScheduler, PhysicalPageDecoder, ScheduledScanLine, SchedulerContext, + PageInfo, PageScheduler, PrimitivePageDecoder, ScheduledScanLine, 
SchedulerContext, SchedulingJob, }, encoder::{ArrayEncodingStrategy, EncodeTask, EncodedColumn, EncodedPage, FieldEncoder}, @@ -39,7 +41,7 @@ use crate::{ #[derive(Debug)] struct PrimitivePage { scheduler: Box, - num_rows: u32, + num_rows: u64, } /// A field scheduler for primitive fields @@ -67,14 +69,15 @@ impl PrimitiveFieldScheduler { column_buffers: buffers, positions_and_sizes: &page.buffer_offsets_and_sizes, }; - let scheduler = decoder_from_array_encoding(&page.encoding, &page_buffers); + let scheduler = + decoder_from_array_encoding(&page.encoding, &page_buffers, &data_type); PrimitivePage { scheduler, num_rows: page.num_rows, } }) .collect::>(); - let num_rows = page_schedulers.iter().map(|p| p.num_rows as u64).sum(); + let num_rows = page_schedulers.iter().map(|p| p.num_rows).sum(); Self { data_type, page_schedulers, @@ -124,8 +127,8 @@ impl<'a> SchedulingJob for PrimitiveFieldSchedulingJob<'a> { cur_page.num_rows ); // Skip entire pages until we have some overlap with our next range - while cur_page.num_rows as u64 + self.global_row_offset <= range.start { - self.global_row_offset += cur_page.num_rows as u64; + while cur_page.num_rows + self.global_row_offset <= range.start { + self.global_row_offset += cur_page.num_rows; self.page_idx += 1; trace!("Skipping entire page of {} rows", cur_page.num_rows); cur_page = &self.scheduler.page_schedulers[self.page_idx]; @@ -135,14 +138,14 @@ impl<'a> SchedulingJob for PrimitiveFieldSchedulingJob<'a> { // until we find a range that exceeds the current page let mut ranges_in_page = Vec::new(); - while cur_page.num_rows as u64 + self.global_row_offset > range.start { + while cur_page.num_rows + self.global_row_offset > range.start { range.start = range.start.max(self.global_row_offset); let start_in_page = range.start - self.global_row_offset; let end_in_page = start_in_page + (range.end - range.start); - let end_in_page = end_in_page.min(cur_page.num_rows as u64) as u32; - let last_in_range = (end_in_page as u64 + 
self.global_row_offset) >= range.end; + let end_in_page = end_in_page.min(cur_page.num_rows); + let last_in_range = (end_in_page + self.global_row_offset) >= range.end; - ranges_in_page.push(start_in_page as u32..end_in_page); + ranges_in_page.push(start_in_page..end_in_page); if last_in_range { self.range_idx += 1; if self.range_idx == self.ranges.len() { @@ -162,7 +165,7 @@ impl<'a> SchedulingJob for PrimitiveFieldSchedulingJob<'a> { cur_page.num_rows ); - self.global_row_offset += cur_page.num_rows as u64; + self.global_row_offset += cur_page.num_rows; self.page_idx += 1; let physical_decoder = @@ -211,17 +214,17 @@ impl FieldScheduler for PrimitiveFieldScheduler { pub struct PrimitiveFieldDecoder { data_type: DataType, - unloaded_physical_decoder: Option>>>, - physical_decoder: Option>, - num_rows: u32, - rows_drained: u32, + unloaded_physical_decoder: Option>>>, + physical_decoder: Option>, + num_rows: u64, + rows_drained: u64, } impl PrimitiveFieldDecoder { pub fn new_from_data( - physical_decoder: Arc, + physical_decoder: Arc, data_type: DataType, - num_rows: u32, + num_rows: u64, ) -> Self { Self { data_type, @@ -244,48 +247,27 @@ impl Debug for PrimitiveFieldDecoder { } struct PrimitiveFieldDecodeTask { - rows_to_skip: u32, - rows_to_take: u32, - physical_decoder: Arc, + rows_to_skip: u64, + rows_to_take: u64, + physical_decoder: Arc, data_type: DataType, } impl DecodeArrayTask for PrimitiveFieldDecodeTask { fn decode(self: Box) -> Result { - // We start by assuming that no buffers are required. The number of buffers needed is based - // on the data type. 
Most data types need two buffers but each layer of fixed-size-list, for - // example, adds another validity buffer - let mut capacities = vec![(0, false); self.physical_decoder.num_buffers() as usize]; let mut all_null = false; - self.physical_decoder.update_capacity( - self.rows_to_skip, - self.rows_to_take, - &mut capacities, - &mut all_null, - ); + + // The number of buffers needed is based on the data type. + // Most data types need two buffers but each layer of fixed-size-list, for + // example, adds another validity buffer. + let bufs = + self.physical_decoder + .decode(self.rows_to_skip, self.rows_to_take, &mut all_null)?; if all_null { return Ok(new_null_array(&self.data_type, self.rows_to_take as usize)); } - // At this point we know the size needed for each buffer - let mut bufs = capacities - .into_iter() - .map(|(num_bytes, is_needed)| { - // Only allocate the validity buffer if it is needed, otherwise we - // create an empty BytesMut (does not require allocation) - if is_needed { - BytesMut::with_capacity(num_bytes as usize) - } else { - BytesMut::default() - } - }) - .collect::>(); - - // Go ahead and fill the validity / values buffers - self.physical_decoder - .decode_into(self.rows_to_skip, self.rows_to_take, &mut bufs)?; - // Convert the two buffers into an Arrow array Self::primitive_array_from_buffers(&self.data_type, bufs, self.rows_to_take) } @@ -297,7 +279,7 @@ impl PrimitiveFieldDecodeTask { // into a primitive array is pretty fundamental. 
fn new_primitive_array( buffers: Vec, - num_rows: u32, + num_rows: u64, data_type: &DataType, ) -> ArrayRef { let mut buffer_iter = buffers.into_iter(); @@ -323,7 +305,50 @@ impl PrimitiveFieldDecodeTask { ) } - fn bytes_to_validity(bytes: BytesMut, num_rows: u32) -> Option { + fn new_generic_byte_array(buffers: Vec, num_rows: u64) -> ArrayRef { + // iterate over buffers to get offsets and then bytes + let mut buffer_iter = buffers.into_iter(); + + let null_buffer = buffer_iter.next().unwrap(); + let null_buffer = if null_buffer.is_empty() { + None + } else { + let null_buffer = null_buffer.freeze().into(); + Some(NullBuffer::new(BooleanBuffer::new( + Buffer::from_bytes(null_buffer), + 0, + num_rows as usize, + ))) + }; + + let indices_bytes = buffer_iter.next().unwrap().freeze(); + let indices_buffer = Buffer::from_bytes(indices_bytes.into()); + let indices_buffer = + ScalarBuffer::::new(indices_buffer, 0, num_rows as usize + 1); + + let offsets = OffsetBuffer::new(indices_buffer.clone()); + + // TODO - add NULL support + // Decoding the bytes creates 2 buffers, the first one is empty due to nulls. 
+ buffer_iter.next().unwrap(); + + let bytes_buffer = buffer_iter.next().unwrap().freeze(); + let bytes_buffer = Buffer::from_bytes(bytes_buffer.into()); + let bytes_buffer_len = bytes_buffer.len(); + let bytes_buffer = ScalarBuffer::::new(bytes_buffer, 0, bytes_buffer_len); + + let bytes_array = Arc::new( + PrimitiveArray::::new(bytes_buffer, None).with_data_type(DataType::UInt8), + ); + + Arc::new(GenericByteArray::::new( + offsets, + bytes_array.values().into(), + null_buffer, + )) + } + + fn bytes_to_validity(bytes: BytesMut, num_rows: u64) -> Option { if bytes.is_empty() { None } else { @@ -339,7 +364,7 @@ impl PrimitiveFieldDecodeTask { fn primitive_array_from_buffers( data_type: &DataType, buffers: Vec, - num_rows: u32, + num_rows: u64, ) -> Result { match data_type { DataType::Boolean => { @@ -484,7 +509,7 @@ impl PrimitiveFieldDecodeTask { let items_array = Self::primitive_array_from_buffers( items.data_type(), remaining_buffers, - num_rows * (*dimension as u32), + num_rows * (*dimension as u64), )?; Ok(Arc::new(FixedSizeListArray::new( items.clone(), @@ -493,36 +518,18 @@ impl PrimitiveFieldDecodeTask { fsl_nulls, ))) } - DataType::Utf8 => { - // iterate over buffers to get offsets and then bytes - let mut buffer_iter = buffers.into_iter(); - let indices_bytes = buffer_iter.next().unwrap().freeze(); - let indices_buffer = Buffer::from_bytes(indices_bytes.into()); - let indices_buffer = - ScalarBuffer::::new(indices_buffer, 0, num_rows as usize + 1); - - let offsets = OffsetBuffer::new(indices_buffer.clone()); - - // TODO - add NULL support - // Decoding the bytes creates 2 buffers, the first one is empty due to nulls. 
- let _null_buffer = buffer_iter.next().unwrap(); - - let bytes_buffer = buffer_iter.next().unwrap().freeze(); - let bytes_buffer = Buffer::from_bytes(bytes_buffer.into()); - let bytes_buffer_len = bytes_buffer.len(); - let bytes_buffer = ScalarBuffer::::new(bytes_buffer, 0, bytes_buffer_len); - - let bytes_array = Arc::new( - PrimitiveArray::::new(bytes_buffer, None) - .with_data_type(DataType::UInt8), - ); - - Ok(Arc::new(StringArray::new( - offsets, - bytes_array.values().into(), - None, - ))) - } + DataType::Utf8 => Ok(Self::new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::LargeUtf8 => Ok(Self::new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::Binary => Ok(Self::new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::LargeBinary => Ok(Self::new_generic_byte_array::>( + buffers, num_rows, + )), _ => Err(Error::io( format!( "The data type {} cannot be decoded from a primitive encoding", @@ -537,7 +544,7 @@ impl PrimitiveFieldDecodeTask { impl LogicalPageDecoder for PrimitiveFieldDecoder { // TODO: In the future, at some point, we may consider partially waiting for primitive pages by // breaking up large I/O into smaller I/O as a way to accelerate the "time-to-first-decode" - fn wait(&mut self, _: u32) -> BoxFuture> { + fn wait(&mut self, _: u64) -> BoxFuture> { async move { let physical_decoder = self.unloaded_physical_decoder.take().unwrap().await?; self.physical_decoder = Some(Arc::from(physical_decoder)); @@ -546,7 +553,7 @@ impl LogicalPageDecoder for PrimitiveFieldDecoder { .boxed() } - fn drain(&mut self, num_rows: u32) -> Result { + fn drain(&mut self, num_rows: u64) -> Result { let rows_to_skip = self.rows_drained; let rows_to_take = num_rows; @@ -566,7 +573,7 @@ impl LogicalPageDecoder for PrimitiveFieldDecoder { }) } - fn unawaited(&self) -> u32 { + fn unawaited(&self) -> u64 { if self.unloaded_physical_decoder.is_some() { self.num_rows } else { @@ -574,7 +581,7 @@ impl LogicalPageDecoder for 
PrimitiveFieldDecoder { } } - fn avail(&self) -> u32 { + fn avail(&self) -> u64 { if self.unloaded_physical_decoder.is_some() { 0 } else { @@ -685,7 +692,7 @@ impl PrimitiveFieldEncoder { let column_idx = self.column_index; Ok(tokio::task::spawn(async move { - let num_rows = arrays.iter().map(|arr| arr.len() as u32).sum(); + let num_rows = arrays.iter().map(|arr| arr.len() as u64).sum(); let mut buffer_index = 0; let array = encoder.encode(&arrays, &mut buffer_index)?; Ok(EncodedPage { diff --git a/rust/lance-encoding/src/encodings/logical/struct.rs b/rust/lance-encoding/src/encodings/logical/struct.rs index 8fb5731469..6dc2609886 100644 --- a/rust/lance-encoding/src/encodings/logical/struct.rs +++ b/rust/lance-encoding/src/encodings/logical/struct.rs @@ -129,14 +129,14 @@ impl<'a> SchedulingJob for SimpleStructSchedulerJob<'a> { child_scan.rows_scheduled, next_child.col_idx ); - next_child.rows_scheduled += child_scan.rows_scheduled as u64; - next_child.rows_remaining -= child_scan.rows_scheduled as u64; + next_child.rows_scheduled += child_scan.rows_scheduled; + next_child.rows_remaining -= child_scan.rows_scheduled; decoders.extend(child_scan.decoders); self.children.push(next_child); self.rows_scheduled = self.children.peek().unwrap().rows_scheduled; context = scoped.pop(); } - let struct_rows_scheduled = (self.rows_scheduled - old_rows_scheduled) as u32; + let struct_rows_scheduled = self.rows_scheduled - old_rows_scheduled; Ok(ScheduledScanLine { decoders, rows_scheduled: struct_rows_scheduled, @@ -162,8 +162,6 @@ impl<'a> SchedulingJob for SimpleStructSchedulerJob<'a> { pub struct SimpleStructScheduler { children: Vec>, child_fields: Fields, - // A single page cannot contain more than u32 rows. However, we also use SimpleStructScheduler - // at the top level and a single file *can* contain more than u32 rows. 
num_rows: u64, } @@ -230,7 +228,7 @@ struct ChildState { struct CompositeDecodeTask { // One per child tasks: Vec>, - num_rows: u32, + num_rows: u64, has_more: bool, } @@ -273,7 +271,7 @@ impl ChildState { let mut remaining = num_rows.saturating_sub(self.rows_available); for next_decoder in &mut self.scheduled { if next_decoder.unawaited() > 0 { - let rows_to_wait = remaining.min(next_decoder.unawaited() as u64) as u32; + let rows_to_wait = remaining.min(next_decoder.unawaited()); trace!( "Struct await an additional {} rows from the current page", rows_to_wait @@ -289,14 +287,14 @@ impl ChildState { next_decoder.wait(rows_to_wait).await?; let newly_avail = next_decoder.avail() - previously_avail; trace!("The await loaded {} rows", newly_avail); - self.rows_available += newly_avail as u64; + self.rows_available += newly_avail; // Need to use saturating_sub here because we might have asked for range // 0-1000 and this page we just loaded might cover 900-1100 and so newly_avail // is 200 but rows_unawaited is only 100 // // TODO: Unit tests may not be covering this branch right now - self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail as u64); - remaining -= rows_to_wait as u64; + self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail); + remaining -= rows_to_wait; if remaining == 0 { break; } @@ -322,13 +320,13 @@ impl ChildState { }; while remaining > 0 { let next = self.scheduled.front_mut().unwrap(); - let rows_to_take = remaining.min(next.avail() as u64) as u32; + let rows_to_take = remaining.min(next.avail()); let next_task = next.drain(rows_to_take)?; if next.avail() == 0 && next.unawaited() == 0 { trace!("Completely drained page"); self.scheduled.pop_front(); } - remaining -= rows_to_take as u64; + remaining -= rows_to_take; composite.tasks.push(next_task.task); composite.num_rows += next_task.num_rows; } @@ -357,53 +355,6 @@ impl SimpleStructDecoder { data_type, } } - - pub fn avail_u64(&self) -> u64 { - self.children - .iter() 
- .map(|c| c.rows_available) - .min() - .unwrap() - } - - // Rows are unawaited if they are unawaited in any child column - pub fn unawaited_u64(&self) -> u64 { - self.children - .iter() - .map(|c| c.rows_unawaited) - .max() - .unwrap() - } - - pub fn wait_u64(&mut self, num_rows: u64) -> BoxFuture> { - async move { - for child in self.children.iter_mut() { - child.wait(num_rows).await?; - } - Ok(()) - } - .boxed() - } - - pub fn drain_u64(&mut self, num_rows: u64) -> Result { - let child_tasks = self - .children - .iter_mut() - .map(|child| child.drain(num_rows)) - .collect::>>()?; - let num_rows = child_tasks[0].num_rows; - let has_more = child_tasks[0].has_more; - debug_assert!(child_tasks.iter().all(|task| task.num_rows == num_rows)); - debug_assert!(child_tasks.iter().all(|task| task.has_more == has_more)); - Ok(NextDecodeTask { - task: Box::new(SimpleStructDecodeTask { - children: child_tasks, - child_fields: self.child_fields.clone(), - }), - num_rows, - has_more, - }) - } } impl LogicalPageDecoder for SimpleStructDecoder { @@ -423,21 +374,21 @@ impl LogicalPageDecoder for SimpleStructDecoder { Ok(()) } - fn wait(&mut self, num_rows: u32) -> BoxFuture> { + fn wait(&mut self, num_rows: u64) -> BoxFuture> { async move { for child in self.children.iter_mut() { - child.wait(num_rows as u64).await?; + child.wait(num_rows).await?; } Ok(()) } .boxed() } - fn drain(&mut self, num_rows: u32) -> Result { + fn drain(&mut self, num_rows: u64) -> Result { let child_tasks = self .children .iter_mut() - .map(|child| child.drain(num_rows as u64)) + .map(|child| child.drain(num_rows)) .collect::>>()?; let num_rows = child_tasks[0].num_rows; let has_more = child_tasks[0].has_more; @@ -454,17 +405,21 @@ impl LogicalPageDecoder for SimpleStructDecoder { } // Rows are available only if they are available in every child column - fn avail(&self) -> u32 { - let avail = self.avail_u64(); - debug_assert!(avail <= u32::MAX as u64); - avail as u32 + fn avail(&self) -> u64 { + 
self.children + .iter() + .map(|c| c.rows_available) + .min() + .unwrap() } // Rows are unawaited if they are unawaited in any child column - fn unawaited(&self) -> u32 { - let unawaited = self.unawaited_u64(); - debug_assert!(unawaited <= u32::MAX as u64); - unawaited as u32 + fn unawaited(&self) -> u64 { + self.children + .iter() + .map(|c| c.rows_unawaited) + .max() + .unwrap() } fn data_type(&self) -> &DataType { @@ -495,7 +450,7 @@ impl DecodeArrayTask for SimpleStructDecodeTask { pub struct StructFieldEncoder { children: Vec>, column_index: u32, - num_rows_seen: u32, + num_rows_seen: u64, } impl StructFieldEncoder { @@ -511,7 +466,7 @@ impl StructFieldEncoder { impl FieldEncoder for StructFieldEncoder { fn maybe_encode(&mut self, array: ArrayRef) -> Result> { - self.num_rows_seen += array.len() as u32; + self.num_rows_seen += array.len() as u64; let struct_array = array.as_struct(); let child_tasks = self .children diff --git a/rust/lance-encoding/src/encodings/physical.rs b/rust/lance-encoding/src/encodings/physical.rs index 0c963d4ae5..e3895daefe 100644 --- a/rust/lance-encoding/src/encodings/physical.rs +++ b/rust/lance-encoding/src/encodings/physical.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use arrow_schema::DataType; + use crate::encodings::physical::value::CompressionScheme; use crate::{decoder::PageScheduler, format::pb}; @@ -79,19 +81,30 @@ fn get_buffer_decoder(encoding: &pb::Flat, buffers: &PageBuffers) -> Box Box { match encoding.array_encoding.as_ref().unwrap() { pb::array_encoding::ArrayEncoding::Nullable(basic) => { match basic.nullability.as_ref().unwrap() { - pb::nullable::Nullability::NoNulls(no_nulls) => { - Box::new(BasicPageScheduler::new_non_nullable( - decoder_from_array_encoding(no_nulls.values.as_ref().unwrap(), buffers), - )) - } + pb::nullable::Nullability::NoNulls(no_nulls) => Box::new( + BasicPageScheduler::new_non_nullable(decoder_from_array_encoding( + 
no_nulls.values.as_ref().unwrap(), + buffers, + data_type, + )), + ), pb::nullable::Nullability::SomeNulls(some_nulls) => { Box::new(BasicPageScheduler::new_nullable( - decoder_from_array_encoding(some_nulls.validity.as_ref().unwrap(), buffers), - decoder_from_array_encoding(some_nulls.values.as_ref().unwrap(), buffers), + decoder_from_array_encoding( + some_nulls.validity.as_ref().unwrap(), + buffers, + data_type, + ), + decoder_from_array_encoding( + some_nulls.values.as_ref().unwrap(), + buffers, + data_type, + ), )) } pb::nullable::Nullability::AllNulls(_) => { @@ -102,7 +115,7 @@ pub fn decoder_from_array_encoding( pb::array_encoding::ArrayEncoding::Flat(flat) => get_buffer_decoder(flat, buffers), pb::array_encoding::ArrayEncoding::FixedSizeList(fixed_size_list) => { let item_encoding = fixed_size_list.items.as_ref().unwrap(); - let item_scheduler = decoder_from_array_encoding(item_encoding, buffers); + let item_scheduler = decoder_from_array_encoding(item_encoding, buffers, data_type); Box::new(FixedListScheduler::new( item_scheduler, fixed_size_list.dimension, @@ -112,18 +125,26 @@ pub fn decoder_from_array_encoding( // since we know it is a list based on the schema. In the future there may be different ways // of storing the list offsets. 
pb::array_encoding::ArrayEncoding::List(list) => { - decoder_from_array_encoding(list.offsets.as_ref().unwrap(), buffers) + decoder_from_array_encoding(list.offsets.as_ref().unwrap(), buffers, data_type) } pb::array_encoding::ArrayEncoding::Binary(binary) => { let indices_encoding = binary.indices.as_ref().unwrap(); let bytes_encoding = binary.bytes.as_ref().unwrap(); - let indices_scheduler = decoder_from_array_encoding(indices_encoding, buffers); - let bytes_scheduler = decoder_from_array_encoding(bytes_encoding, buffers); + let indices_scheduler = + decoder_from_array_encoding(indices_encoding, buffers, data_type); + let bytes_scheduler = decoder_from_array_encoding(bytes_encoding, buffers, data_type); + + let offset_type = match data_type { + DataType::LargeBinary | DataType::LargeUtf8 => DataType::Int64, + _ => DataType::Int32, + }; Box::new(BinaryPageScheduler::new( indices_scheduler.into(), bytes_scheduler.into(), + offset_type, + binary.null_adjustment, )) } // Currently there is no way to encode struct nullability and structs are encoded with a "header" column diff --git a/rust/lance-encoding/src/encodings/physical/basic.rs b/rust/lance-encoding/src/encodings/physical/basic.rs index 15fcea00e0..3dccc2cf82 100644 --- a/rust/lance-encoding/src/encodings/physical/basic.rs +++ b/rust/lance-encoding/src/encodings/physical/basic.rs @@ -5,11 +5,12 @@ use std::sync::Arc; use arrow_array::{ArrayRef, BooleanArray}; use arrow_buffer::BooleanBuffer; +use bytes::BytesMut; use futures::{future::BoxFuture, FutureExt}; use log::trace; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, encoder::{ArrayEncoder, BufferEncoder, EncodedArray, EncodedArrayBuffer}, format::pb, EncodingsIo, @@ -20,21 +21,21 @@ use lance_core::Result; use super::buffers::BitmapBufferEncoder; struct DataDecoders { - validity: Box, - values: Box, + validity: Box, + values: Box, } enum DataNullStatus { // Neither validity nor values All, // 
Values only - None(Box), + None(Box), // Validity and values Some(DataDecoders), } impl DataNullStatus { - fn values_decoder(&self) -> Option<&dyn PhysicalPageDecoder> { + fn values_decoder(&self) -> Option<&dyn PrimitivePageDecoder> { match self { Self::All => None, Self::Some(decoders) => Some(decoders.values.as_ref()), @@ -121,10 +122,10 @@ impl BasicPageScheduler { impl PageScheduler for BasicPageScheduler { fn schedule_ranges( &self, - ranges: &[std::ops::Range], + ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>> { + ) -> BoxFuture<'static, Result>> { let validity_future = match &self.mode { SchedulerNullStatus::None(_) | SchedulerNullStatus::All => None, SchedulerNullStatus::Some(schedulers) => Some(schedulers.validity.schedule_ranges( @@ -157,7 +158,7 @@ impl PageScheduler for BasicPageScheduler { } _ => unreachable!(), }; - Ok(Box::new(BasicPageDecoder { mode }) as Box) + Ok(Box::new(BasicPageDecoder { mode }) as Box) } .boxed() } @@ -167,51 +168,39 @@ struct BasicPageDecoder { mode: DataNullStatus, } -impl PhysicalPageDecoder for BasicPageDecoder { - fn update_capacity( +impl PrimitivePageDecoder for BasicPageDecoder { + fn decode( &self, - rows_to_skip: u32, - num_rows: u32, - buffers: &mut [(u64, bool)], + rows_to_skip: u64, + num_rows: u64, all_null: &mut bool, - ) { - // No need to look at the validity decoder to know the dest buffer size since it is boolean - buffers[0].0 = arrow_buffer::bit_util::ceil(num_rows as usize, 8) as u64; - // The validity buffer is only required if we have some nulls - buffers[0].1 = matches!(self.mode, DataNullStatus::Some(_)); - if let Some(values) = self.mode.values_decoder() { - values.update_capacity(rows_to_skip, num_rows, &mut buffers[1..], all_null); - } else { - *all_null = true; - } - } - - fn decode_into( - &self, - rows_to_skip: u32, - num_rows: u32, - dest_buffers: &mut [bytes::BytesMut], - ) -> Result<()> { - match &self.mode { + ) -> Result> { + let 
dest_buffers = match &self.mode { DataNullStatus::Some(decoders) => { - decoders - .validity - .decode_into(rows_to_skip, num_rows, &mut dest_buffers[..1])?; - decoders - .values - .decode_into(rows_to_skip, num_rows, &mut dest_buffers[1..])?; + let mut buffers = decoders.validity.decode(rows_to_skip, num_rows, all_null)?; // buffer 0 + let mut values_bytesmut = + decoders.values.decode(rows_to_skip, num_rows, all_null)?; // buffer 1 onwards + + buffers.append(&mut values_bytesmut); + buffers } // Either dest_buffers[0] is empty, in which case these are no-ops, or one of the // other pages needed the buffer, in which case we need to fill our section DataNullStatus::All => { - dest_buffers[0].fill(0); + let buffers = vec![BytesMut::default()]; + *all_null = true; + buffers } DataNullStatus::None(values) => { - dest_buffers[0].fill(1); - values.decode_into(rows_to_skip, num_rows, &mut dest_buffers[1..])?; + let mut dest_buffers = vec![BytesMut::default()]; + + let mut values_bytesmut = values.decode(rows_to_skip, num_rows, all_null)?; + dest_buffers.append(&mut values_bytesmut); + dest_buffers } - } - Ok(()) + }; + + Ok(dest_buffers) } fn num_buffers(&self) -> u32 { diff --git a/rust/lance-encoding/src/encodings/physical/binary.rs b/rust/lance-encoding/src/encodings/physical/binary.rs index 42937c30cd..f593b35386 100644 --- a/rust/lance-encoding/src/encodings/physical/binary.rs +++ b/rust/lance-encoding/src/encodings/physical/binary.rs @@ -5,18 +5,15 @@ use core::panic; use std::sync::Arc; use arrow_array::cast::AsArray; -use arrow_array::types::UInt32Type; -// use arrow::compute::concat; -use arrow_array::{ - builder::{ArrayBuilder, Int32Builder, UInt32Builder, UInt8Builder}, - Array, ArrayRef, Int32Array, UInt32Array, -}; +use arrow_array::types::UInt64Type; +use arrow_array::{Array, ArrayRef}; +use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, ScalarBuffer}; +use bytes::BytesMut; use futures::stream::StreamExt; use futures::{future::BoxFuture, 
stream::FuturesOrdered, FutureExt}; -// use rand::seq::index; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, encoder::{ArrayEncoder, EncodedArray}, format::pb, EncodingsIo, @@ -25,36 +22,100 @@ use crate::{ use crate::decoder::LogicalPageDecoder; use crate::encodings::logical::primitive::PrimitiveFieldDecoder; -use arrow_array::PrimitiveArray; +use arrow_array::{PrimitiveArray, UInt64Array, UInt8Array}; use arrow_schema::DataType; use lance_core::Result; -use std::ops::Deref; + +struct IndicesNormalizer { + indices: Vec, + validity: BooleanBufferBuilder, + null_adjustment: u64, +} + +impl IndicesNormalizer { + fn new(num_rows: u64, null_adjustment: u64) -> Self { + let mut indices = Vec::with_capacity(num_rows as usize); + indices.push(0); + Self { + indices, + validity: BooleanBufferBuilder::new(num_rows as usize), + null_adjustment, + } + } + + fn normalize(&self, val: u64) -> (bool, u64) { + if val >= self.null_adjustment { + (false, val - self.null_adjustment) + } else { + (true, val) + } + } + + fn extend(&mut self, new_indices: &PrimitiveArray, is_start: bool) { + let mut last = *self.indices.last().unwrap(); + if is_start { + let (is_valid, val) = self.normalize(new_indices.value(0)); + self.indices.push(val); + self.validity.append(is_valid); + last += val; + } + let mut prev = self.normalize(*new_indices.values().first().unwrap()).1; + for w in new_indices.values().windows(2) { + let (is_valid, val) = self.normalize(w[1]); + let next = val - prev + last; + self.indices.push(next); + self.validity.append(is_valid); + prev = val; + last = next; + } + } + + fn into_parts(mut self) -> (Vec, BooleanBuffer) { + (self.indices, self.validity.finish()) + } +} #[derive(Debug)] pub struct BinaryPageScheduler { indices_scheduler: Arc, bytes_scheduler: Arc, + offsets_type: DataType, + null_adjustment: u64, } impl BinaryPageScheduler { pub fn new( indices_scheduler: Arc, bytes_scheduler: Arc, + 
offsets_type: DataType, + null_adjustment: u64, ) -> Self { Self { indices_scheduler, bytes_scheduler, + offsets_type, + null_adjustment, } } } +impl BinaryPageScheduler { + fn decode_indices(decoder: Arc, num_rows: u64) -> Result { + let mut primitive_wrapper = + PrimitiveFieldDecoder::new_from_data(decoder, DataType::UInt64, num_rows); + let drained_task = primitive_wrapper.drain(num_rows)?; + let indices_decode_task = drained_task.task; + indices_decode_task.decode() + } +} + impl PageScheduler for BinaryPageScheduler { fn schedule_ranges( &self, - ranges: &[std::ops::Range], + ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>> { + ) -> BoxFuture<'static, Result>> { // ranges corresponds to row ranges that the user wants to fetch. // if user wants row range a..b // Case 1: if a != 0, we need indices a-1..b to decode @@ -68,22 +129,25 @@ impl PageScheduler for BinaryPageScheduler { 0..(range.end) } }) - .collect::>>(); + .collect::>>(); + + let num_rows = ranges.iter().map(|r| r.end - r.start).sum::(); - let mut futures_ordered = FuturesOrdered::new(); - for range in indices_ranges.iter() { - let indices_page_decoder = + let mut futures_ordered = indices_ranges + .iter() + .map(|range| { self.indices_scheduler - .schedule_ranges(&[range.clone()], scheduler, top_level_row); - futures_ordered.push_back(indices_page_decoder); - } + .schedule_ranges(&[range.clone()], scheduler, top_level_row) + }) + .collect::>(); let ranges = ranges.to_vec(); let copy_scheduler = scheduler.clone(); let copy_bytes_scheduler = self.bytes_scheduler.clone(); - let copy_indices_ranges = indices_ranges.to_vec(); + let null_adjustment = self.null_adjustment; + let offsets_type = self.offsets_type.clone(); - async move { + tokio::spawn(async move { // For the following data: // "abcd", "hello", "abcd", "apple", "hello", "abcd" // 4, 9, 13, 18, 23, 27 @@ -96,91 +160,67 @@ impl PageScheduler for BinaryPageScheduler { // These are the normalized 
offsets stored in decoded_indices // Rest of the workflow is continued later in BinaryPageDecoder - let mut builder = UInt32Builder::new(); + let mut indices_builder = IndicesNormalizer::new(num_rows, null_adjustment); let mut bytes_ranges = Vec::new(); let mut curr_range_idx = 0; - let mut last = 0; while let Some(indices_page_decoder) = futures_ordered.next().await { - let indices: Arc = Arc::from(indices_page_decoder?); + let decoder = Arc::from(indices_page_decoder?); // Build and run decode task for offsets - let curr_indices_range = copy_indices_ranges[curr_range_idx].clone(); + let curr_indices_range = indices_ranges[curr_range_idx].clone(); let curr_row_range = ranges[curr_range_idx].clone(); let indices_num_rows = curr_indices_range.end - curr_indices_range.start; - let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data( - indices, - DataType::UInt32, - indices_num_rows, - ); - let drained_task = primitive_wrapper.drain(indices_num_rows)?; - let indices_decode_task = drained_task.task; - let decoded_part = indices_decode_task.decode()?; - - let indices_array = decoded_part.as_primitive::(); - let mut indices_vec = indices_array.values().to_vec(); - - // Pad a zero at the start if the first row is requested - // This is because the offsets do not start with 0 by default - if curr_row_range.start == 0 { - indices_vec.insert(0, 0); - } - - // Normalize the indices as described above - let normalized_indices: PrimitiveArray = indices_vec - .iter() - .map(|x| x - indices_vec[0] + last) - .collect(); - last = normalized_indices.value(normalized_indices.len() - 1); - let normalized_vec = normalized_indices.values().to_vec(); - - // The first vector to be normalized should not have the leading zero removed - let truncated_vec = if curr_range_idx == 0 { - normalized_vec.as_slice() - } else { - &normalized_vec[1..] 
- }; - builder.append_slice(truncated_vec); + let indices = Self::decode_indices(decoder, indices_num_rows)?; + let indices = indices.as_primitive::(); - // get bytes range from the index range - let bytes_range = if curr_row_range.start != 0 { - indices_array.value(0)..indices_array.value(indices_array.len() - 1) + let first = if curr_row_range.start == 0 { + 0 } else { - 0..indices_array.value(indices_array.len() - 1) + indices_builder + .normalize(*indices.values().first().unwrap()) + .1 }; + let last = indices_builder + .normalize(*indices.values().last().unwrap()) + .1; + bytes_ranges.push(first..last); + + indices_builder.extend(indices, curr_row_range.start == 0); - bytes_ranges.push(bytes_range); curr_range_idx += 1; } - let decoded_indices = Arc::new(builder.finish()); - - let bytes_ranges_slice = bytes_ranges.as_slice(); + let (indices, validity) = indices_builder.into_parts(); + let decoded_indices = UInt64Array::from(indices); // Schedule the bytes for decoding - let bytes_page_decoder = copy_bytes_scheduler.schedule_ranges( - bytes_ranges_slice, - ©_scheduler, - top_level_row, - ); + let bytes_page_decoder = + copy_bytes_scheduler.schedule_ranges(&bytes_ranges, ©_scheduler, top_level_row); - let bytes_decoder: Box = bytes_page_decoder.await?; + let bytes_decoder: Box = bytes_page_decoder.await?; Ok(Box::new(BinaryPageDecoder { decoded_indices, + validity, + offsets_type, bytes_decoder, - }) as Box) - } + }) as Box) + }) + // Propagate join panic + .map(|join_handle| join_handle.unwrap()) .boxed() } } struct BinaryPageDecoder { - decoded_indices: Arc, - bytes_decoder: Box, + decoded_indices: UInt64Array, + offsets_type: DataType, + validity: BooleanBuffer, + bytes_decoder: Box, } -impl PhysicalPageDecoder for BinaryPageDecoder { +impl PrimitivePageDecoder for BinaryPageDecoder { // Continuing the example from BinaryPageScheduler // Suppose batch_size = 2. 
Then first, rows_to_skip=0, num_rows=2 // Need to scan 2 rows @@ -188,79 +228,95 @@ impl PhysicalPageDecoder for BinaryPageDecoder { // Allocate 8 bytes capacity. // Next rows_to_skip=2, num_rows=1 // Skip 8 bytes. Allocate 5 bytes capacity. - fn update_capacity( - &self, - rows_to_skip: u32, - num_rows: u32, - buffers: &mut [(u64, bool)], - all_null: &mut bool, - ) { - let offsets = self - .decoded_indices - .as_any() - .downcast_ref::() - .unwrap(); - - // 32 bits or 4 bytes per value. - buffers[0].0 = (num_rows as u64) * 4; - buffers[0].1 = true; - - let bytes_to_skip = offsets.value(rows_to_skip as usize); - let num_bytes = offsets.value((rows_to_skip + num_rows) as usize) - bytes_to_skip; - - self.bytes_decoder - .update_capacity(bytes_to_skip, num_bytes, &mut buffers[1..], all_null); - } - - // Continuing from update_capacity: - // When rows_to_skip=2, num_rows=1 + // // The normalized offsets are [0, 4, 8, 13] // We only need [8, 13] to decode in this case. // These need to be normalized in order to build the string later // So return [0, 5] - fn decode_into( + fn decode( &self, - rows_to_skip: u32, - num_rows: u32, - dest_buffers: &mut [bytes::BytesMut], - ) -> Result<()> { - let offsets = self + rows_to_skip: u64, + num_rows: u64, + all_null: &mut bool, + ) -> Result> { + // Buffers[0] == validity buffer + // Buffers[1] == offsets buffer + // Buffers[2] == null buffer // TODO: Micro-optimization, can we get rid of this? 
Doesn't hurt much though + // This buffer is always empty since bytes are not allowed to contain nulls + // Buffers[3] == bytes buffer + + // STEP 1: validity buffer + let target_validity = self + .validity + .slice(rows_to_skip as usize, num_rows as usize); + let has_nulls = target_validity.count_set_bits() < target_validity.len(); + + let validity_buffer = if has_nulls { + let num_validity_bits = arrow_buffer::bit_util::ceil(num_rows as usize, 8); + let mut validity_buffer = BytesMut::with_capacity(num_validity_bits); + + if rows_to_skip == 0 { + validity_buffer.extend_from_slice(target_validity.inner().as_slice()); + } else { + // Need to copy the buffer because there may be a bit offset in first byte + let target_validity = BooleanBuffer::from_iter(target_validity.iter()); + validity_buffer.extend_from_slice(target_validity.inner().as_slice()); + } + validity_buffer + } else { + BytesMut::new() + }; + + // STEP 2: offsets buffer + // Currently we always do a copy here, we need to cast to the appropriate type + // and we go ahead and normalize so the starting offset is 0 (though we could skip + // this) + let bytes_per_offset = match self.offsets_type { + DataType::Int32 => 4, + DataType::Int64 => 8, + _ => panic!("Unsupported offsets type"), + }; + + let target_offsets = self .decoded_indices - .as_any() - .downcast_ref::() - .unwrap(); - - let bytes_to_skip = offsets.value(rows_to_skip as usize); - let num_bytes = offsets.value((rows_to_skip + num_rows) as usize) - bytes_to_skip; - - let target_offsets = offsets.slice( - rows_to_skip.try_into().unwrap(), - (num_rows + 1).try_into().unwrap(), - ); + .slice(rows_to_skip as usize, (num_rows + 1) as usize); - // Normalize offsets + // Normalize and cast (TODO: could fuse these into one pass for micro-optimization) let target_vec = target_offsets.values(); - let normalized_array: PrimitiveArray = - target_vec.iter().map(|x| x - target_vec[0]).collect(); - let normalized_values = normalized_array.values(); - - let 
byte_slice = normalized_values.inner().deref(); + let start = target_vec[0]; + let offsets_buffer = + match bytes_per_offset { + 4 => ScalarBuffer::from_iter(target_vec.iter().map(|x| (x - start) as i32)) + .into_inner(), + 8 => ScalarBuffer::from_iter(target_vec.iter().map(|x| (x - start) as i64)) + .into_inner(), + _ => panic!("Unsupported offsets type"), + }; + // TODO: This forces a second copy, which is unfortunate, try and remove in the future + let offsets_buf = BytesMut::from(offsets_buffer.as_slice()); + + let bytes_to_skip = self.decoded_indices.value(rows_to_skip as usize); + let num_bytes = self + .decoded_indices + .value((rows_to_skip + num_rows) as usize) + - bytes_to_skip; - // copy target_offsets into dest_buffers[0] - dest_buffers[0].extend_from_slice(byte_slice); + let mut output_buffers = vec![validity_buffer, offsets_buf]; - // Copy decoded bytes into dest_buffers[1..] + // Copy decoded bytes into dest_buffers[2..] // Currently an empty null buffer is the first one // The actual bytes are in the second buffer - // Including the indices this results in 3 buffers in total - self.bytes_decoder - .decode_into(bytes_to_skip, num_bytes, &mut dest_buffers[1..])?; + // Including the indices this results in 4 buffers in total + output_buffers.extend( + self.bytes_decoder + .decode(bytes_to_skip, num_bytes, all_null)?, + ); - Ok(()) + Ok(output_buffers) } fn num_buffers(&self) -> u32 { - self.bytes_decoder.num_buffers() + 1 + self.bytes_decoder.num_buffers() + 2 } } @@ -286,149 +342,302 @@ impl BinaryEncoder { // Strings are a vector of arrays corresponding to each record batch // Zero offset is removed from the start of the offsets array // The indices array is computed across all arrays in the vector -fn get_indices_from_string_arrays(arrays: &[ArrayRef]) -> ArrayRef { - let mut indices_builder = Int32Builder::new(); - let mut last_offset = 0; - arrays.iter().for_each(|arr| { - let string_arr = arrow_array::cast::as_string_array(arr); - let offsets = 
string_arr.offsets().inner(); - let mut offsets = offsets.slice(1, offsets.len() - 1).to_vec(); - - if indices_builder.len() == 0 { - last_offset = offsets[offsets.len() - 1]; +fn get_indices_from_string_arrays(arrays: &[ArrayRef]) -> (ArrayRef, u64) { + let num_rows = arrays.iter().map(|arr| arr.len()).sum::(); + let mut indices = Vec::with_capacity(num_rows); + let mut last_offset = 0_u64; + for array in arrays { + if let Some(array) = array.as_string_opt::() { + let offsets = array.offsets().inner(); + indices.extend(offsets.windows(2).map(|w| { + let strlen = (w[1] - w[0]) as u64; + let off = strlen + last_offset; + last_offset = off; + off + })); + } else if let Some(array) = array.as_string_opt::() { + let offsets = array.offsets().inner(); + indices.extend(offsets.windows(2).map(|w| { + let strlen = (w[1] - w[0]) as u64; + let off = strlen + last_offset; + last_offset = off; + off + })); + } else if let Some(array) = array.as_binary_opt::() { + let offsets = array.offsets().inner(); + indices.extend(offsets.windows(2).map(|w| { + let strlen = (w[1] - w[0]) as u64; + let off = strlen + last_offset; + last_offset = off; + off + })); + } else if let Some(array) = array.as_binary_opt::() { + let offsets = array.offsets().inner(); + indices.extend(offsets.windows(2).map(|w| { + let strlen = (w[1] - w[0]) as u64; + let off = strlen + last_offset; + last_offset = off; + off + })); } else { - offsets = offsets - .iter() - .map(|offset| offset + last_offset) - .collect::>(); - last_offset = offsets[offsets.len() - 1]; + panic!("Array is not a string array"); } + } + let last_offset = *indices.last().expect("Indices array is empty"); + // 8 exabytes in a single array seems unlikely but...just in case + assert!( + last_offset < u64::MAX / 2, + "Indices array with strings up to 2^63 is too large for this encoding" + ); + let null_adjustment: u64 = *indices.last().expect("Indices array is empty") + 1; + + let mut indices_offset = 0; + for array in arrays { + if let 
Some(nulls) = array.nulls() { + let indices_slice = &mut indices[indices_offset..indices_offset + array.len()]; + indices_slice + .iter_mut() + .zip(nulls.iter()) + .for_each(|(index, is_valid)| { + if !is_valid { + *index += null_adjustment; + } + }); + } + indices_offset += array.len(); + } - let new_int_arr = Int32Array::from(offsets); - indices_builder.append_slice(new_int_arr.values()); - }); - - Arc::new(indices_builder.finish()) as ArrayRef + (Arc::new(UInt64Array::from(indices)), null_adjustment) } // Bytes computed across all string arrays, similar to indices above -fn get_bytes_from_string_arrays(arrays: &[ArrayRef]) -> ArrayRef { - let mut bytes_builder = UInt8Builder::new(); - arrays.iter().for_each(|arr| { - let string_arr = arrow_array::cast::as_string_array(arr); - let values = string_arr.values(); - bytes_builder.append_slice(values); - }); - - Arc::new(bytes_builder.finish()) as ArrayRef +fn get_bytes_from_string_arrays(arrays: &[ArrayRef]) -> Vec { + arrays + .iter() + .map(|arr| { + let (values_buffer, start, stop) = if let Some(arr) = arr.as_string_opt::() { + ( + arr.values(), + arr.offsets()[0] as usize, + arr.offsets()[arr.len()] as usize, + ) + } else if let Some(arr) = arr.as_string_opt::() { + ( + arr.values(), + arr.offsets()[0] as usize, + arr.offsets()[arr.len()] as usize, + ) + } else if let Some(arr) = arr.as_binary_opt::() { + ( + arr.values(), + arr.offsets()[0] as usize, + arr.offsets()[arr.len()] as usize, + ) + } else if let Some(arr) = arr.as_binary_opt::() { + ( + arr.values(), + arr.offsets()[0] as usize, + arr.offsets()[arr.len()] as usize, + ) + } else { + panic!("Array is not a string / binary array"); + }; + let values = ScalarBuffer::new(values_buffer.clone(), start, stop - start); + Arc::new(UInt8Array::new(values, None)) as ArrayRef + }) + .collect() } impl ArrayEncoder for BinaryEncoder { fn encode(&self, arrays: &[ArrayRef], buffer_index: &mut u32) -> Result { - let (null_count, _row_count) = arrays - .iter() - 
.map(|arr| (arr.null_count() as u32, arr.len() as u32)) - .fold((0, 0), |acc, val| (acc.0 + val.0, acc.1 + val.1)); + let (index_array, null_adjustment) = get_indices_from_string_arrays(arrays); + let encoded_indices = self.indices_encoder.encode(&[index_array], buffer_index)?; + + let byte_arrays = get_bytes_from_string_arrays(arrays); + let encoded_bytes = self.bytes_encoder.encode(&byte_arrays, buffer_index)?; + + let mut encoded_buffers = encoded_indices.buffers; + encoded_buffers.extend(encoded_bytes.buffers); + + Ok(EncodedArray { + buffers: encoded_buffers, + encoding: pb::ArrayEncoding { + array_encoding: Some(pb::array_encoding::ArrayEncoding::Binary(Box::new( + pb::Binary { + indices: Some(Box::new(encoded_indices.encoding)), + bytes: Some(Box::new(encoded_bytes.encoding)), + null_adjustment, + }, + ))), + }, + }) + } +} - if null_count != 0 { - panic!("Data contains null values, not currently supported for binary data.") - } else { - let index_array = get_indices_from_string_arrays(arrays); - let encoded_indices = self.indices_encoder.encode(&[index_array], buffer_index)?; - - let byte_array = get_bytes_from_string_arrays(arrays); - let encoded_bytes = self.bytes_encoder.encode(&[byte_array], buffer_index)?; - - let mut encoded_buffers = encoded_indices.buffers; - encoded_buffers.extend(encoded_bytes.buffers); - - Ok(EncodedArray { - buffers: encoded_buffers, - encoding: pb::ArrayEncoding { - array_encoding: Some(pb::array_encoding::ArrayEncoding::Binary(Box::new( - pb::Binary { - indices: Some(Box::new(encoded_indices.encoding)), - bytes: Some(Box::new(encoded_bytes.encoding)), - }, - ))), - }, - }) +#[cfg(test)] +pub mod tests { + + use arrow_array::{ + builder::{LargeStringBuilder, StringBuilder}, + ArrayRef, LargeStringArray, StringArray, UInt64Array, + }; + use arrow_schema::{DataType, Field}; + use std::{sync::Arc, vec}; + + use crate::testing::{ + check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases, + }; + + use 
super::get_indices_from_string_arrays; + + #[test_log::test(tokio::test)] + async fn test_utf8() { + let field = Field::new("", DataType::Utf8, false); + check_round_trip_encoding_random(field).await; + } + + #[test] + fn test_encode_indices_stitches_offsets() { + // Given two string arrays we might have offsets [5, 10, 15] and [0, 3, 7] + // + // We need to stitch them to [0, 5, 10, 13, 17] + let string_array1 = StringArray::from(vec![Some("abcde"), Some("abcde"), Some("abcde")]); + let string_array1 = Arc::new(string_array1.slice(1, 2)); + let string_array2 = Arc::new(StringArray::from(vec![Some("abc"), Some("abcd")])); + let (offsets, null_adjustment) = + get_indices_from_string_arrays(&[string_array1, string_array2]); + + let expected = Arc::new(UInt64Array::from(vec![5, 10, 13, 17])) as ArrayRef; + assert_eq!(&offsets, &expected); + assert_eq!(null_adjustment, 18); + } + + #[test] + fn test_encode_indices_adjusts_nulls() { + // Null entries in string arrays should be adjusted + let string_array1 = Arc::new(StringArray::from(vec![None, Some("foo")])); + let string_array2 = Arc::new(StringArray::from(vec![Some("foo"), None])); + let string_array3 = Arc::new(StringArray::from(vec![None as Option<&str>, None])); + let (offsets, null_adjustment) = + get_indices_from_string_arrays(&[string_array1, string_array2, string_array3]); + + let expected = Arc::new(UInt64Array::from(vec![7, 3, 6, 13, 13, 13])) as ArrayRef; + assert_eq!(&offsets, &expected); + assert_eq!(null_adjustment, 7); + } + + #[test] + fn test_encode_indices_string_types() { + let string_array = Arc::new(LargeStringArray::from(vec![Some("foo")])); + let large_string_array = Arc::new(LargeStringArray::from(vec![Some("foo")])); + let binary_array = Arc::new(LargeStringArray::from(vec![Some("foo")])); + let large_binary_array = Arc::new(LargeStringArray::from(vec![Some("foo")])); + + for arr in [ + string_array, + large_string_array, + binary_array, + large_binary_array, + ] { + let (offsets, 
null_adjustment) = get_indices_from_string_arrays(&[arr]); + let expected = Arc::new(UInt64Array::from(vec![3])) as ArrayRef; + assert_eq!(&offsets, &expected); + assert_eq!(null_adjustment, 4); } - // Currently not handling all null cases in this array encoder. - // TODO: Separate behavior for all null rows vs some null rows - // else if null_count == row_count { - // let nullability = pb::nullable::Nullability::AllNulls(pb::nullable::AllNull {}); - - // Ok(EncodedArray { - // buffers: vec![], - // encoding: pb::ArrayEncoding { - // array_encoding: Some(pb::array_encoding::ArrayEncoding::Nullable(Box::new( - // pb::Nullable { - // nullability: Some(nullability), - // }, - // ))), - // }, - // }) - // } - - // let arr_encoding = self.values_encoder.encode(arrays, buffer_index)?; - // let encoding = pb::nullable::Nullability::NoNulls(Box::new(pb::nullable::NoNull { - // values: Some(Box::new(arr_encoding.encoding)), - // })); - // (arr_encoding.buffers, encoding) - // } else if null_count == row_count { - // let encoding = pb::nullable::Nullability::AllNulls(pb::nullable::AllNull {}); - // (vec![], encoding) - // } else { - // let validity_as_arrays = arrays - // .iter() - // .map(|arr| { - // if let Some(nulls) = arr.nulls() { - // Arc::new(BooleanArray::new(nulls.inner().clone(), None)) as ArrayRef - // } else { - // let buff = BooleanBuffer::new_set(arr.len()); - // Arc::new(BooleanArray::new(buff, None)) as ArrayRef - // } - // }) - // .collect::>(); - - // let validity_buffer_index = *buffer_index; - // *buffer_index += 1; - // let validity = BitmapBufferEncoder::default().encode(&validity_as_arrays)?; - // let validity_encoding = Box::new(pb::ArrayEncoding { - // array_encoding: Some(pb::array_encoding::ArrayEncoding::Flat(pb::Flat { - // bits_per_value: 1, - // buffer: Some(pb::Buffer { - // buffer_index: validity_buffer_index, - // buffer_type: pb::buffer::BufferType::Page as i32, - // }), - // compression: None, - // })), - // }); - - // let arr_encoding = 
self.values_encoder.encode(arrays, buffer_index)?; - // let encoding = pb::nullable::Nullability::SomeNulls(Box::new(pb::nullable::SomeNull { - // validity: Some(validity_encoding), - // values: Some(Box::new(arr_encoding.encoding)), - // })); - - // let mut buffers = arr_encoding.buffers; - // buffers.push(EncodedArrayBuffer { - // parts: validity.parts, - // index: validity_buffer_index, - // }); - // (buffers, encoding) - // }; - - // Ok(EncodedArray { - // buffers, - // encoding: pb::ArrayEncoding { - // array_encoding: Some(pb::array_encoding::ArrayEncoding::Nullable(Box::new( - // pb::Nullable { - // nullability: Some(nullability), - // }, - // ))), - // }, - // }) + } + + #[test_log::test(tokio::test)] + async fn test_binary() { + let field = Field::new("", DataType::Binary, false); + check_round_trip_encoding_random(field).await; + } + + #[test_log::test(tokio::test)] + async fn test_large_binary() { + let field = Field::new("", DataType::LargeBinary, true); + check_round_trip_encoding_random(field).await; + } + + #[test_log::test(tokio::test)] + async fn test_large_utf8() { + let field = Field::new("", DataType::LargeUtf8, true); + check_round_trip_encoding_random(field).await; + } + + #[test_log::test(tokio::test)] + async fn test_simple_utf8() { + let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_range(0..3) + .with_range(1..3) + .with_indices(vec![1, 3]); + check_round_trip_encoding_of_data(vec![Arc::new(string_array)], &test_cases).await; + } + + #[test_log::test(tokio::test)] + async fn test_sliced_utf8() { + let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]); + let string_array = string_array.slice(1, 3); + + let test_cases = TestCases::default() + .with_range(0..1) + .with_range(0..2) + .with_range(1..2); + check_round_trip_encoding_of_data(vec![Arc::new(string_array)], &test_cases).await; + } + + 
#[test_log::test(tokio::test)] + async fn test_empty_strings() { + // Scenario 1: Some strings are empty + + let values = [Some("abc"), Some(""), None]; + // Test empty list at beginning, middle, and end + for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] { + let mut string_builder = StringBuilder::new(); + for idx in order { + string_builder.append_option(values[idx]); + } + let string_array = Arc::new(string_builder.finish()); + let test_cases = TestCases::default() + .with_indices(vec![1]) + .with_indices(vec![0]) + .with_indices(vec![2]); + check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases).await; + let test_cases = test_cases.with_batch_size(1); + check_round_trip_encoding_of_data(vec![string_array], &test_cases).await; + } + + // Scenario 2: All strings are empty + + // When encoding an array of empty strings there are no bytes to encode + // which is strange and we want to ensure we handle it + let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")])); + + let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]); + check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases).await; + let test_cases = test_cases.with_batch_size(1); + check_round_trip_encoding_of_data(vec![string_array], &test_cases).await; + } + + #[test_log::test(tokio::test)] + #[ignore] // This test is quite slow in debug mode + async fn test_jumbo_string() { + // This is an overflow test. We have a list of lists where each list + // has 1Mi items. 
We encode 5000 of these lists and so we have over 4Gi in the + // offsets range + let mut string_builder = LargeStringBuilder::new(); + // a 1 MiB string + let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0')); + for _ in 0..5000 { + string_builder.append_option(Some(&giant_string)); + } + let giant_array = Arc::new(string_builder.finish()) as ArrayRef; + let arrs = vec![giant_array]; + + // // We can't validate because our validation relies on concatenating all input arrays + let test_cases = TestCases::default().without_validation(); + check_round_trip_encoding_of_data(arrs, &test_cases).await; } } diff --git a/rust/lance-encoding/src/encodings/physical/bitmap.rs b/rust/lance-encoding/src/encodings/physical/bitmap.rs index b95be7580d..eaf9dffb84 100644 --- a/rust/lance-encoding/src/encodings/physical/bitmap.rs +++ b/rust/lance-encoding/src/encodings/physical/bitmap.rs @@ -11,7 +11,7 @@ use lance_core::Result; use log::trace; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, EncodingsIo, }; @@ -33,19 +33,19 @@ impl DenseBitmapScheduler { impl PageScheduler for DenseBitmapScheduler { fn schedule_ranges( &self, - ranges: &[Range], + ranges: &[Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>> { + ) -> BoxFuture<'static, Result>> { let mut min = u64::MAX; let mut max = 0; let chunk_reqs = ranges .iter() .map(|range| { debug_assert_ne!(range.start, range.end); - let start = self.buffer_offset + range.start as u64 / 8; + let start = self.buffer_offset + range.start / 8; let bit_offset = range.start % 8; - let end = self.buffer_offset + range.end.div_ceil(8) as u64; + let end = self.buffer_offset + range.end.div_ceil(8); let byte_range = start..end; min = min.min(start); max = max.max(end); @@ -76,7 +76,7 @@ impl PageScheduler for DenseBitmapScheduler { length, }) .collect::>(); - Ok(Box::new(BitmapDecoder { chunks }) as Box) + Ok(Box::new(BitmapDecoder { chunks }) 
as Box) } .boxed() } @@ -84,32 +84,24 @@ impl PageScheduler for DenseBitmapScheduler { struct BitmapData { data: Bytes, - bit_offset: u32, - length: u32, + bit_offset: u64, + length: u64, } struct BitmapDecoder { chunks: Vec, } -impl PhysicalPageDecoder for BitmapDecoder { - fn update_capacity( +impl PrimitivePageDecoder for BitmapDecoder { + fn decode( &self, - _rows_to_skip: u32, - num_rows: u32, - buffers: &mut [(u64, bool)], + rows_to_skip: u64, + num_rows: u64, _all_null: &mut bool, - ) { - buffers[0].0 = arrow_buffer::bit_util::ceil(num_rows as usize, 8) as u64; - buffers[0].1 = true; - } + ) -> Result> { + let num_bytes = arrow_buffer::bit_util::ceil(num_rows as usize, 8); + let mut dest_buffers = vec![BytesMut::with_capacity(num_bytes)]; - fn decode_into( - &self, - rows_to_skip: u32, - num_rows: u32, - dest_buffers: &mut [BytesMut], - ) -> Result<()> { let mut rows_to_skip = rows_to_skip; let mut dest_builder = BooleanBufferBuilder::new(num_rows as usize); @@ -146,7 +138,7 @@ impl PhysicalPageDecoder for BitmapDecoder { // // It's a moot point at the moment since we don't support page bridging dest_buffers[0].copy_from_slice(bool_buffer.as_slice()); - Ok(()) + Ok(dest_buffers) } fn num_buffers(&self) -> u32 { @@ -157,9 +149,9 @@ impl PhysicalPageDecoder for BitmapDecoder { #[cfg(test)] mod tests { use arrow_schema::{DataType, Field}; - use bytes::{Bytes, BytesMut}; + use bytes::Bytes; - use crate::decoder::PhysicalPageDecoder; + use crate::decoder::PrimitivePageDecoder; use crate::encodings::physical::bitmap::BitmapData; use crate::testing::check_round_trip_encoding_random; @@ -189,8 +181,8 @@ mod tests { }, ], }; - let mut dest = vec![BytesMut::with_capacity(1)]; - let result = decoder.decode_into(5, 1, &mut dest); + + let result = decoder.decode(5, 1, &mut false); assert!(result.is_ok()); } } diff --git a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs index 
0d69f95aab..b9308b0af2 100644 --- a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs +++ b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs @@ -4,12 +4,13 @@ use std::sync::Arc; use arrow_array::{cast::AsArray, ArrayRef}; +use bytes::BytesMut; use futures::{future::BoxFuture, FutureExt}; use lance_core::Result; use log::trace; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, encoder::{ArrayEncoder, EncodedArray}, format::pb, EncodingsIo, @@ -36,13 +37,13 @@ impl FixedListScheduler { impl PageScheduler for FixedListScheduler { fn schedule_ranges( &self, - ranges: &[std::ops::Range], + ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>> { + ) -> BoxFuture<'static, Result>> { let expanded_ranges = ranges .iter() - .map(|range| (range.start * self.dimension)..(range.end * self.dimension)) + .map(|range| (range.start * self.dimension as u64)..(range.end * self.dimension as u64)) .collect::>(); trace!( "Expanding {} fsl ranges across {}..{} to item ranges across {}..{}", @@ -60,43 +61,28 @@ impl PageScheduler for FixedListScheduler { let items_decoder = inner_page_decoder.await?; Ok(Box::new(FixedListDecoder { items_decoder, - dimension, - }) as Box) + dimension: dimension as u64, + }) as Box) } .boxed() } } pub struct FixedListDecoder { - items_decoder: Box, - dimension: u32, + items_decoder: Box, + dimension: u64, } -impl PhysicalPageDecoder for FixedListDecoder { - fn update_capacity( +impl PrimitivePageDecoder for FixedListDecoder { + fn decode( &self, - rows_to_skip: u32, - num_rows: u32, - buffers: &mut [(u64, bool)], + rows_to_skip: u64, + num_rows: u64, all_null: &mut bool, - ) { - let rows_to_skip = rows_to_skip * self.dimension; - let num_rows = num_rows * self.dimension; - self.items_decoder - .update_capacity(rows_to_skip, num_rows, buffers, all_null); - } - - fn decode_into( - &self, - rows_to_skip: u32, - num_rows: u32, - 
dest_buffers: &mut [bytes::BytesMut], - ) -> Result<()> { + ) -> Result> { let rows_to_skip = rows_to_skip * self.dimension; let num_rows = num_rows * self.dimension; - self.items_decoder - .decode_into(rows_to_skip, num_rows, dest_buffers)?; - Ok(()) + self.items_decoder.decode(rows_to_skip, num_rows, all_null) } fn num_buffers(&self) -> u32 { diff --git a/rust/lance-encoding/src/encodings/physical/value.rs b/rust/lance-encoding/src/encodings/physical/value.rs index 3e8407d6fa..fb9aa88c73 100644 --- a/rust/lance-encoding/src/encodings/physical/value.rs +++ b/rust/lance-encoding/src/encodings/physical/value.rs @@ -3,7 +3,7 @@ use arrow_array::ArrayRef; use arrow_schema::DataType; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use futures::{future::BoxFuture, FutureExt}; use lance_arrow::DataTypeExt; use log::trace; @@ -13,7 +13,7 @@ use std::ops::Range; use std::sync::{Arc, Mutex}; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, encoder::{ArrayEncoder, BufferEncoder, EncodedArray, EncodedArrayBuffer}, format::pb, EncodingsIo, @@ -82,17 +82,17 @@ impl ValuePageScheduler { impl PageScheduler for ValuePageScheduler { fn schedule_ranges( &self, - ranges: &[std::ops::Range], + ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>> { + ) -> BoxFuture<'static, Result>> { let (mut min, mut max) = (u64::MAX, 0); let byte_ranges = if self.compression_scheme == CompressionScheme::None { ranges .iter() .map(|range| { - let start = self.buffer_offset + (range.start as u64 * self.bytes_per_value); - let end = self.buffer_offset + (range.end as u64 * self.bytes_per_value); + let start = self.buffer_offset + (range.start * self.bytes_per_value); + let end = self.buffer_offset + (range.end * self.bytes_per_value); min = min.min(start); max = max.max(end); start..end @@ -122,8 +122,8 @@ impl PageScheduler for ValuePageScheduler { ranges .iter() .map(|range| { - let start = 
(range.start as u64 * bytes_per_value) as usize; - let end = (range.end as u64 * bytes_per_value) as usize; + let start = (range.start * bytes_per_value) as usize; + let end = (range.end * bytes_per_value) as usize; start..end }) .collect::>() @@ -139,7 +139,7 @@ impl PageScheduler for ValuePageScheduler { data: bytes, uncompressed_data: Arc::new(Mutex::new(None)), uncompressed_range_offsets: range_offsets, - }) as Box) + }) as Box) } .boxed() } @@ -203,26 +203,17 @@ impl ValuePageDecoder { } } -impl PhysicalPageDecoder for ValuePageDecoder { - fn update_capacity( +impl PrimitivePageDecoder for ValuePageDecoder { + fn decode( &self, - _rows_to_skip: u32, - num_rows: u32, - buffers: &mut [(u64, bool)], + rows_to_skip: u64, + num_rows: u64, _all_null: &mut bool, - ) { - buffers[0].0 = self.bytes_per_value * num_rows as u64; - buffers[0].1 = true; - } + ) -> Result> { + let mut bytes_to_skip = rows_to_skip * self.bytes_per_value; + let mut bytes_to_take = num_rows * self.bytes_per_value; - fn decode_into( - &self, - rows_to_skip: u32, - num_rows: u32, - dest_buffers: &mut [bytes::BytesMut], - ) -> Result<()> { - let mut bytes_to_skip = rows_to_skip as u64 * self.bytes_per_value; - let mut bytes_to_take = num_rows as u64 * self.bytes_per_value; + let mut dest_buffers = vec![BytesMut::with_capacity(bytes_to_take as usize)]; let dest = &mut dest_buffers[0]; @@ -238,7 +229,7 @@ impl PhysicalPageDecoder for ValuePageDecoder { self.decode_buffer(buf, &mut bytes_to_skip, &mut bytes_to_take, dest); } } - Ok(()) + Ok(dest_buffers) } fn num_buffers(&self) -> u32 { diff --git a/rust/lance-file/src/v2/reader.rs b/rust/lance-file/src/v2/reader.rs index e4383949e5..3fe182a0d1 100644 --- a/rust/lance-file/src/v2/reader.rs +++ b/rust/lance-file/src/v2/reader.rs @@ -7,7 +7,6 @@ use arrow_schema::Schema as ArrowSchema; use byteorder::{ByteOrder, LittleEndian, ReadBytesExt}; use bytes::{Bytes, BytesMut}; use futures::{stream::BoxStream, FutureExt, Stream, StreamExt}; -use 
lance_arrow::DataTypeExt; use lance_encoding::{ decoder::{ BatchDecodeStream, ColumnInfo, DecodeBatchScheduler, DecoderMiddlewareChain, @@ -37,10 +36,7 @@ use crate::{ format::{pb, pbfile, MAGIC, MAJOR_VERSION, MINOR_VERSION_NEXT}, }; -use lance_encoding::encoder::get_str_encoding_type; - use super::io::LanceEncodingsIo; -use arrow_schema::DataType; // For now, we don't use global buffers for anything other than schema. If we // use these later we should make them lazily loaded and then cached once loaded. @@ -459,15 +455,11 @@ impl FileReader { // Helper function for `default_projection` to determine how many columns are occupied // by a lance field. fn default_column_count(field: &Field) -> u32 { - if field.data_type().is_binary_like() { - 2 - } else { - 1 + field - .children - .iter() - .map(Self::default_column_count) - .sum::() - } + 1 + field + .children + .iter() + .map(Self::default_column_count) + .sum::() } // This function is one of the few spots in the reader where we rely on Lance table @@ -529,19 +521,6 @@ impl FileReader { column_infos.push(self.metadata.column_infos[*column_idx].clone()); *column_idx += 1; - if get_str_encoding_type() { - // use str array encoding - if (field.data_type().is_binary_like()) && (field.data_type() != DataType::Utf8) { - // These types are 2 columns in a lance file but a single field id in a lance schema - column_infos.push(self.metadata.column_infos[*column_idx].clone()); - *column_idx += 1; - } - } else if field.data_type().is_binary_like() { - // These types are 2 columns in a lance file but a single field id in a lance schema - column_infos.push(self.metadata.column_infos[*column_idx].clone()); - *column_idx += 1; - } - for child in &field.children { self.collect_columns(child, column_idx, column_infos)?; } @@ -950,12 +929,7 @@ impl EncodedBatchReaderExt for EncodedBatch { data: bytes, num_rows: page_table .first() - .map(|col| { - col.page_infos - .iter() - .map(|page| page.num_rows as u64) - .sum::() - }) + 
.map(|col| col.page_infos.iter().map(|page| page.num_rows).sum::()) .unwrap_or(0), page_table, schema: Arc::new(schema.clone()), @@ -998,12 +972,7 @@ impl EncodedBatchReaderExt for EncodedBatch { data: bytes, num_rows: page_table .first() - .map(|col| { - col.page_infos - .iter() - .map(|page| page.num_rows as u64) - .sum::() - }) + .map(|col| col.page_infos.iter().map(|page| page.num_rows).sum::()) .unwrap_or(0), page_table, schema: Arc::new(schema.clone()), @@ -1018,7 +987,6 @@ pub mod tests { use arrow_array::{ types::{Float64Type, Int32Type}, RecordBatch, - // StringArray, UInt32Array, }; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema}; use bytes::Bytes; @@ -1030,11 +998,8 @@ pub mod tests { decoder::{decode_batch, DecoderMiddlewareChain, FilterExpression}, encoder::{encode_batch, CoreFieldEncodingStrategy, EncodedBatch}, }; - // use lance_io::object_store::ObjectStore; - // use lance_io::scheduler::ScanScheduler; use lance_io::stream::RecordBatchStream; use log::debug; - // use object_store::path::Path; use crate::v2::{ reader::{EncodedBatchReaderExt, FileReader, ReaderProjection}, @@ -1053,6 +1018,8 @@ pub mod tests { .col("score", array::rand::()) .col("location", array::rand_type(&location_type)) .col("categories", array::rand_type(&categories_type)) + .col("binary", array::rand_type(&DataType::Binary)) + .col("large_bin", array::rand_type(&DataType::LargeBinary)) .into_reader_rows(RowCount::from(1000), BatchCount::from(100)); write_lance_file(reader, fs, FileWriterOptions::default()).await @@ -1450,174 +1417,4 @@ pub mod tests { let buf = file_reader.read_global_buffer(1).await.unwrap(); assert_eq!(buf, test_bytes); } - - // fn test_reading_rangefrom( - // schema: Arc, - // ) -> (lance_io::ReadBatchParams, Vec) { - // let result_batch = RecordBatch::try_new( - // schema, - // vec![ - // Arc::new(UInt32Array::from(vec![3, 4, 5, 6])), - // Arc::new(StringArray::from(vec!["abcd", "apple", "hello", "abcd"])), - // ], - // ) - // .unwrap(); - 
- // let result_batches = vec![result_batch]; - // let read_params = lance_io::ReadBatchParams::RangeFrom(2..); - - // (read_params, result_batches) - // } - - // fn test_reading_rangeto( - // schema: Arc, - // ) -> (lance_io::ReadBatchParams, Vec) { - // let read_params = lance_io::ReadBatchParams::RangeTo(..4); - // let result_batch = RecordBatch::try_new( - // schema, - // vec![ - // Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), - // Arc::new(StringArray::from(vec!["abcd", "hello", "abcd", "apple"])), - // ], - // ) - // .unwrap(); - - // let result_batches = vec![result_batch]; - - // (read_params, result_batches) - // } - - // fn test_reading_random_indices( - // schema: Arc, - // ) -> (lance_io::ReadBatchParams, Vec) { - // let row_indices_vec = vec![0, 2, 4]; - // let row_indices = UInt32Array::from(row_indices_vec); - // let read_params = lance_io::ReadBatchParams::from(row_indices); - // let result_batch = RecordBatch::try_new( - // schema, - // vec![ - // Arc::new(UInt32Array::from(vec![1, 3, 5])), - // Arc::new(StringArray::from(vec!["abcd", "abcd", "hello"])), - // ], - // ) - // .unwrap(); - - // let result_batches = vec![result_batch]; - - // (read_params, result_batches) - // } - - // fn test_reading_partial_range( - // schema: Arc, - // ) -> (lance_io::ReadBatchParams, Vec) { - // let read_params = lance_io::ReadBatchParams::Range(2..4); - // let result_batch = RecordBatch::try_new( - // schema, - // vec![ - // Arc::new(UInt32Array::from(vec![3, 4])), - // Arc::new(StringArray::from(vec!["abcd", "apple"])), - // ], - // ) - // .unwrap(); - - // let result_batches = vec![result_batch]; - - // (read_params, result_batches) - // } - - // #[tokio::test] - // async fn test_string_array_encoding() { - // // set env var temporarily to test string array encoding - // let _env_guard = EnvVarGuard::new("LANCE_STR_ARRAY_ENCODING", "binary"); - - // let tmp_dir = tempfile::tempdir().unwrap(); - // let tmp_path: String = 
tmp_dir.path().to_str().unwrap().to_owned(); - // let tmp_path = Path::parse(tmp_path).unwrap(); - // let tmp_path = tmp_path.child("some_file.lance"); - // let obj_store = Arc::new(ObjectStore::local()); - // let writer = obj_store.create(&tmp_path).await.unwrap(); - - // let schema = Arc::new(ArrowSchema::new(vec![ - // Field::new("key", DataType::UInt32, false), - // Field::new("strings", DataType::Utf8, false), - // ])); - - // let batch1 = RecordBatch::try_new( - // schema.clone(), - // vec![ - // Arc::new(UInt32Array::from(vec![1, 2, 3])), - // Arc::new(StringArray::from(vec!["abcd", "hello", "abcd"])), - // ], - // ) - // .unwrap(); - - // let batch2 = RecordBatch::try_new( - // schema.clone(), - // vec![ - // Arc::new(UInt32Array::from(vec![4, 5, 6])), - // Arc::new(StringArray::from(vec!["apple", "hello", "abcd"])), - // ], - // ) - // .unwrap(); - - // let batches = vec![batch1, batch2]; - // let lance_schema = lance_core::datatypes::Schema::try_from(schema.as_ref()).unwrap(); - - // let mut file_writer = FileWriter::try_new( - // writer, - // tmp_path.to_string(), - // lance_schema, - // FileWriterOptions::default(), - // ) - // .unwrap(); - - // for batch in batches.clone() { - // file_writer.write_batch(&batch).await.unwrap(); - // } - - // file_writer.finish().await.unwrap(); - - // let object_store = Arc::new(ObjectStore::local()); - // let fs_scheduler = ScanScheduler::new(object_store.clone(), 8); - // let file_scheduler = fs_scheduler.open_file(&tmp_path).await.unwrap(); - - // let file_reader = - // FileReader::try_open(file_scheduler, None, DecoderMiddlewareChain::default()) - // .await - // .unwrap(); - - // for batch_size in [1, 2, 1024] { - // // Read different types of ranges from the file - // let (read_params, result_batches) = test_reading_rangefrom(schema.clone()); - // let batch_stream = file_reader - // .read_stream(read_params, batch_size, 16, FilterExpression::no_filter()) - // .unwrap(); - // 
verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await; - - // let (read_params, result_batches) = test_reading_rangeto(schema.clone()); - // let batch_stream = file_reader - // .read_stream(read_params, batch_size, 16, FilterExpression::no_filter()) - // .unwrap(); - // verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await; - - // let (read_params, result_batches) = test_reading_random_indices(schema.clone()); - // let batch_stream = file_reader - // .read_stream(read_params, batch_size, 16, FilterExpression::no_filter()) - // .unwrap(); - // verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await; - - // let (read_params, result_batches) = test_reading_partial_range(schema.clone()); - // let batch_stream = file_reader - // .read_stream(read_params, batch_size, 16, FilterExpression::no_filter()) - // .unwrap(); - // verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await; - - // let read_params = lance_io::ReadBatchParams::RangeFull; - // let result_batches = batches.clone(); - // let batch_stream = file_reader - // .read_stream(read_params, batch_size, 16, FilterExpression::no_filter()) - // .unwrap(); - // verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await; - // } - // } } diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index 2f88fa684f..0991ae6e00 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -76,6 +76,10 @@ harness = false name = "pq_dist_table" harness = false +[[bench]] +name = "pq_assignment" +harness = false + [[bench]] name = "hnsw" harness = false diff --git a/rust/lance-index/benches/find_partitions.rs b/rust/lance-index/benches/find_partitions.rs index 9e68602684..dd370128f0 100644 --- a/rust/lance-index/benches/find_partitions.rs +++ b/rust/lance-index/benches/find_partitions.rs @@ -11,7 +11,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; 
#[cfg(target_os = "linux")] use pprof::criterion::{Output, PProfProfiler}; -use lance_index::vector::ivf::Ivf; +use lance_index::vector::ivf::IvfTransformer; use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_seed; @@ -27,14 +27,14 @@ fn bench_partitions(c: &mut Criterion) { let fsl = FixedSizeListArray::try_new_from_values(centroids, DIMENSION as i32).unwrap(); for k in &[1, 10, 50] { - let ivf = Ivf::new(fsl.clone(), DistanceType::L2, vec![]); + let ivf = IvfTransformer::new(fsl.clone(), DistanceType::L2, vec![]); c.bench_function(format!("IVF{},k={},L2", num_centroids, k).as_str(), |b| { b.iter(|| { let _ = ivf.find_partitions(&query, *k); }) }); - let ivf = Ivf::new(fsl.clone(), DistanceType::Cosine, vec![]); + let ivf = IvfTransformer::new(fsl.clone(), DistanceType::Cosine, vec![]); c.bench_function( format!("IVF{},k={},Cosine", num_centroids, k).as_str(), |b| { @@ -45,7 +45,7 @@ fn bench_partitions(c: &mut Criterion) { ); } - let ivf = Ivf::new(fsl.clone(), DistanceType::L2, vec![]); + let ivf = IvfTransformer::new(fsl.clone(), DistanceType::L2, vec![]); let batch = generate_random_array_with_seed::(DIMENSION * 4096, SEED); let fsl = FixedSizeListArray::try_new_from_values(batch, DIMENSION as i32).unwrap(); c.bench_function( diff --git a/rust/lance-index/benches/pq_assignment.rs b/rust/lance-index/benches/pq_assignment.rs new file mode 100644 index 0000000000..bf448633f7 --- /dev/null +++ b/rust/lance-index/benches/pq_assignment.rs @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Benchmark of Building PQ code from Dense Vectors. 
+ +use std::sync::Arc; + +use arrow_array::{types::Float32Type, FixedSizeListArray}; +use criterion::{criterion_group, criterion_main, Criterion}; +use lance_arrow::FixedSizeListArrayExt; +use lance_index::vector::pq::{ProductQuantizer, ProductQuantizerImpl}; +use lance_linalg::distance::DistanceType; +use lance_testing::datagen::generate_random_array_with_seed; + +const PQ: usize = 96; +const DIM: usize = 1536; +const TOTAL: usize = 32 * 1024; + +fn pq_transform(c: &mut Criterion) { + let codebook = Arc::new(generate_random_array_with_seed::( + 256 * DIM, + [88; 32], + )); + + let vectors = generate_random_array_with_seed::(DIM * TOTAL, [3; 32]); + let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); + + for dt in [DistanceType::L2, DistanceType::Dot].iter() { + let pq = ProductQuantizerImpl::::new(PQ, 8, DIM, codebook.clone(), *dt); + + c.bench_function(format!("{},{}", dt, TOTAL).as_str(), |b| { + b.iter(|| { + let _ = pq.transform(&fsl).unwrap(); + }) + }); + } +} + +criterion_group!( + name=benches; + config = Criterion::default().significance_level(0.1).sample_size(10); + targets = pq_transform); + +criterion_main!(benches); diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 95492daf5f..18521f9222 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -52,6 +52,9 @@ pub trait Index: Send + Sync + DeepSizeOf { /// Cast to [Index] fn as_index(self: Arc) -> Arc; + /// Cast to [vector::VectorIndex] + fn as_vector_index(self: Arc) -> Result>; + /// Retrieve index statistics as a JSON Value fn statistics(&self) -> Result; diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 39916cab2c..752d56cc6d 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -789,6 +789,13 @@ impl Index for BTreeIndex { self } + fn as_vector_index(self: Arc) -> Result> { + Err(Error::NotSupported { + source: "BTreeIndex is not 
vector index".into(), + location: location!(), + }) + } + fn index_type(&self) -> IndexType { IndexType::Scalar } diff --git a/rust/lance-index/src/scalar/flat.rs b/rust/lance-index/src/scalar/flat.rs index 2d7ac03bce..842773929a 100644 --- a/rust/lance-index/src/scalar/flat.rs +++ b/rust/lance-index/src/scalar/flat.rs @@ -14,8 +14,9 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_physical_expr::expressions::{in_list, lit, Column}; use deepsize::DeepSizeOf; use lance_core::utils::address::RowAddress; -use lance_core::Result; +use lance_core::{Error, Result}; use roaring::RoaringBitmap; +use snafu::{location, Location}; use crate::{Index, IndexType}; @@ -157,6 +158,13 @@ impl Index for FlatIndex { self } + fn as_vector_index(self: Arc) -> Result> { + Err(Error::NotSupported { + source: "FlatIndex is not vector index".into(), + location: location!(), + }) + } + fn index_type(&self) -> IndexType { IndexType::Scalar } diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index ebaee29f1f..63ad4955f3 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -6,13 +6,16 @@ use std::{collections::HashMap, sync::Arc}; -use arrow_array::{ArrayRef, RecordBatch}; +use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; use arrow_schema::Field; use async_trait::async_trait; +use ivf::storage::IvfModel; use lance_core::{Result, ROW_ID_FIELD}; use lance_io::traits::Reader; use lance_linalg::distance::DistanceType; use lazy_static::lazy_static; +use quantizer::{QuantizationType, Quantizer}; +use v3::subindex::SubIndexType; pub mod bq; pub mod flat; @@ -102,6 +105,7 @@ impl From for pb::VectorMetricType { } /// Vector Index for (Approximate) Nearest Neighbor (ANN) Search. +/// It's always the IVF index, any other index types without partitioning will be treated as IVF with one partition. 
#[async_trait] #[allow(clippy::redundant_pub_crate)] pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { @@ -125,6 +129,15 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// - Only supports `f32` now. Will add f64/f16 later. async fn search(&self, query: &Query, pre_filter: Arc) -> Result; + fn find_partitions(&self, query: &Query) -> Result; + + async fn search_in_partition( + &self, + partition_id: usize, + query: &Query, + pre_filter: Arc, + ) -> Result; + /// If the index is loadable by IVF, so it can be a sub-index that /// is loaded on demand by IVF. fn is_loadable(&self) -> bool; @@ -136,6 +149,9 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// explaining why not fn check_can_remap(&self) -> Result<()>; + // async fn append(&self, batches: Vec) -> Result<()>; + // async fn merge(&self, indices: Vec>) -> Result<()>; + /// Load the index from the reader on-demand. async fn load( &self, @@ -170,4 +186,10 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// The metric type of this vector index. fn metric_type(&self) -> DistanceType; + + fn ivf_model(&self) -> IvfModel; + fn quantizer(&self) -> Quantizer; + + /// the index type of this vector index. 
+ fn sub_index_type(&self) -> (SubIndexType, QuantizationType); } diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index 0a7c7d97f8..8ffbfe97c0 100644 --- a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -6,14 +6,16 @@ use std::{collections::HashSet, sync::Arc}; +use arrow::array::AsArray; use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use deepsize::DeepSizeOf; use itertools::Itertools; -use lance_core::{Result, ROW_ID_FIELD}; +use lance_core::{Error, Result, ROW_ID_FIELD}; use lance_file::reader::FileReader; use lance_linalg::distance::DistanceType; use serde::{Deserialize, Serialize}; +use snafu::{location, Location}; use crate::{ prefilter::PreFilter, @@ -61,6 +63,10 @@ impl IvfSubIndex for FlatIndex { "FLAT" } + fn metadata_key() -> &'static str { + "lance:flat" + } + fn schema() -> arrow_schema::SchemaRef { Schema::new(vec![Field::new("__flat_marker", DataType::UInt64, false)]).into() } @@ -74,25 +80,44 @@ impl IvfSubIndex for FlatIndex { prefilter: Arc, ) -> Result { let dist_calc = storage.dist_calculator(query); - let filtered_row_ids = prefilter - .filter_row_ids(Box::new(storage.row_ids())) - .into_iter() - .collect::>(); - let (row_ids, dists): (Vec, Vec) = (0..storage.len()) - .filter(|&id| !filtered_row_ids.contains(&storage.row_id(id as u32))) - .map(|id| OrderedNode { - id: id as u32, - dist: OrderedFloat(dist_calc.distance(id as u32)), - }) - .sorted_unstable() - .take(k) - .map( - |OrderedNode { - id, - dist: OrderedFloat(dist), - }| (storage.row_id(id), dist), - ) - .unzip(); + + let (row_ids, dists): (Vec, Vec) = match prefilter.is_empty() { + true => (0..storage.len()) + .map(|id| OrderedNode { + id: id as u32, + dist: OrderedFloat(dist_calc.distance(id as u32)), + }) + .sorted_unstable() + .take(k) + .map( + |OrderedNode { + id, + dist: OrderedFloat(dist), + }| 
(storage.row_id(id), dist), + ) + .unzip(), + false => { + let filtered_row_ids = prefilter + .filter_row_ids(Box::new(storage.row_ids())) + .into_iter() + .collect::>(); + (0..storage.len()) + .filter(|&id| !filtered_row_ids.contains(&storage.row_id(id as u32))) + .map(|id| OrderedNode { + id: id as u32, + dist: OrderedFloat(dist_calc.distance(id as u32)), + }) + .sorted_unstable() + .take(k) + .map( + |OrderedNode { + id, + dist: OrderedFloat(dist), + }| (storage.row_id(id), dist), + ) + .unzip() + } + }; let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists)); @@ -143,9 +168,15 @@ impl FlatQuantizer { } impl Quantization for FlatQuantizer { + type BuildParams = (); type Metadata = FlatMetadata; type Storage = FlatStorage; + fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result { + let dim = data.as_fixed_size_list().value_length(); + Ok(Self::new(dim as usize, distance_type)) + } + fn code_dim(&self) -> usize { self.dim } @@ -173,7 +204,7 @@ impl Quantization for FlatQuantizer { "flat" } - fn quantization_type(&self) -> QuantizationType { + fn quantization_type() -> QuantizationType { QuantizationType::Flat } @@ -187,3 +218,17 @@ impl From for Quantizer { Self::Flat(value) } } + +impl TryFrom for FlatQuantizer { + type Error = Error; + + fn try_from(value: Quantizer) -> Result { + match value { + Quantizer::Flat(quantizer) => Ok(quantizer), + _ => Err(Error::invalid_input( + "quantizer is not FlatQuantizer", + location!(), + )), + } + } +} diff --git a/rust/lance-index/src/vector/flat/storage.rs b/rust/lance-index/src/vector/flat/storage.rs index 62769dc151..7173697233 100644 --- a/rust/lance-index/src/vector/flat/storage.rs +++ b/rust/lance-index/src/vector/flat/storage.rs @@ -26,7 +26,7 @@ use super::index::FlatMetadata; pub const FLAT_COLUMN: &str = "flat"; /// All data are stored in memory -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct FlatStorage { batch: RecordBatch, distance_type: 
DistanceType, diff --git a/rust/lance-index/src/vector/graph.rs b/rust/lance-index/src/vector/graph.rs index fde0d5229e..3cc1bd425d 100644 --- a/rust/lance-index/src/vector/graph.rs +++ b/rust/lance-index/src/vector/graph.rs @@ -232,9 +232,8 @@ pub fn beam_search( dist_calc: &impl DistCalculator, bitset: Option<&roaring::bitmap::RoaringBitmap>, prefetch_distance: Option, - visited_generator: &mut VisitedGenerator, + visited: &mut Visited, ) -> Vec { - let mut visited = visited_generator.generate(graph.len()); //let mut visited: HashSet<_> = HashSet::with_capacity(k); let mut candidates = BinaryHeap::with_capacity(k); visited.insert(ep.id); diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index 072efe6b0a..e9c487b4d0 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -54,9 +54,6 @@ pub struct HnswBuildParams { /// size of the dynamic list for the candidates pub ef_construction: usize, - /// the max number of threads to use for building the graph - pub parallel_limit: Option, - /// number of vectors ahead to prefetch while building the graph pub prefetch_distance: Option, } @@ -67,7 +64,6 @@ impl Default for HnswBuildParams { max_level: 7, m: 20, ef_construction: 150, - parallel_limit: None, prefetch_distance: Some(2), } } @@ -97,17 +93,15 @@ impl HnswBuildParams { self } - /// The max number of threads to use for building the graph. - pub fn parallel_limit(mut self, limit: usize) -> Self { - self.parallel_limit = Some(limit); - self - } - - pub async fn build(self, data: ArrayRef) -> Result { - // We have normalized the vectors if the metric type is cosine, so we can use the L2 distance + /// Build the HNSW index from the given data. + /// + /// # Parameters + /// - `data`: A FixedSizeList to build the HNSW. + /// - `distance_type`: The distance type to use. 
+ pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result { let vec_store = Arc::new(FlatStorage::new( data.as_fixed_size_list().clone(), - DistanceType::L2, + distance_type, )); HNSW::index_vectors(vec_store.as_ref(), self) } @@ -132,6 +126,18 @@ impl Debug for HNSW { } impl HNSW { + pub fn empty() -> Self { + Self { + inner: Arc::new(HnswBuilder { + params: HnswBuildParams::default(), + nodes: Arc::new(Vec::new()), + level_count: Vec::new(), + entry_point: 0, + visited_generator_queue: Arc::new(ArrayQueue::new(1)), + }), + } + } + pub fn len(&self) -> usize { self.inner.nodes.len() } @@ -172,6 +178,7 @@ impl HNSW { } let bottom_level = HnswBottomView::new(nodes); + let mut visited = visited_generator.generate(storage.len()); Ok(beam_search( &bottom_level, &ep, @@ -179,7 +186,7 @@ impl HNSW { &dist_calc, bitset.as_ref(), prefetch_distance, - visited_generator, + &mut visited, ) .into_iter() .take(k) @@ -411,6 +418,7 @@ impl HnswBuilder { visited_generator: &mut VisitedGenerator, ) -> Vec { let cur_level = HnswLevelView::new(level, nodes); + let mut visited = visited_generator.generate(nodes.len()); beam_search( &cur_level, ep, @@ -418,7 +426,7 @@ impl HnswBuilder { dist_calc, None, self.params.prefetch_distance, - visited_generator, + &mut visited, ) } @@ -507,6 +515,10 @@ impl IvfSubIndex for HNSW { where Self: Sized, { + if data.num_rows() == 0 { + return Ok(Self::empty()); + } + let hnsw_metadata = data.schema_ref() .metadata() @@ -515,7 +527,14 @@ impl IvfSubIndex for HNSW { message: format!("{} not found", HNSW_METADATA_KEY), location: location!(), })?; - let hnsw_metadata: HnswMetadata = serde_json::from_str(hnsw_metadata)?; + let hnsw_metadata: HnswMetadata = + serde_json::from_str(hnsw_metadata).map_err(|e| Error::Index { + message: format!( + "Failed to decode HNSW metadata: {}, json: {}", + e, hnsw_metadata + ), + location: location!(), + })?; let levels: Vec<_> = hnsw_metadata .level_offsets @@ -579,6 +598,10 @@ impl IvfSubIndex 
for HNSW { HNSW_TYPE } + fn metadata_key() -> &'static str { + "lance:hnsw" + } + /// Return the schema of the sub index fn schema() -> arrow_schema::SchemaRef { arrow_schema::Schema::new(vec![ @@ -648,21 +671,15 @@ impl IvfSubIndex for HNSW { }; log::info!( - "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}", + "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}", storage.len(), hnsw.inner.params.max_level, hnsw.inner.params.m, - hnsw.inner.params.ef_construction + hnsw.inner.params.ef_construction, + storage.distance_type(), ); let len = storage.len(); - let parallel_limit = hnsw - .inner - .params - .parallel_limit - .unwrap_or_else(num_cpus::get) - .max(1); - log::info!("Building HNSW graph with parallel_limit={}", parallel_limit); hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed); (1..len).into_par_iter().for_each(|node| { let mut visited_generator = VisitedGenerator::new(len); diff --git a/rust/lance-index/src/vector/hnsw/index.rs b/rust/lance-index/src/vector/hnsw/index.rs index 01a4b456f7..783372b9f1 100644 --- a/rust/lance-index/src/vector/hnsw/index.rs +++ b/rust/lance-index/src/vector/hnsw/index.rs @@ -8,7 +8,7 @@ use std::{ sync::Arc, }; -use arrow_array::RecordBatch; +use arrow_array::{RecordBatch, UInt32Array}; use async_trait::async_trait; use deepsize::DeepSizeOf; use lance_core::{datatypes::Schema, Error, Result}; @@ -22,7 +22,9 @@ use snafu::{location, Location}; use tracing::instrument; use crate::prefilter::PreFilter; -use crate::vector::v3::subindex::IvfSubIndex; +use crate::vector::ivf::storage::IvfModel; +use crate::vector::quantizer::QuantizationType; +use crate::vector::v3::subindex::{IvfSubIndex, SubIndexType}; use crate::{ vector::{ graph::NEIGHBORS_FIELD, @@ -35,8 +37,6 @@ use crate::{ Index, IndexType, }; -use super::builder::HNSW_METADATA_KEY; - #[derive(Clone, DeepSizeOf)] pub struct HNSWIndexOptions { pub use_residual: bool, @@ -118,6 +118,11 @@ impl Index for HNSWIndex 
{ self } + /// Cast to [VectorIndex] + fn as_vector_index(self: Arc) -> Result> { + Ok(self) + } + /// Retrieve index statistics as a JSON Value fn statistics(&self) -> Result { Ok(json!({ @@ -166,6 +171,19 @@ impl VectorIndex for HNSWIndex { ) } + fn find_partitions(&self, _: &Query) -> Result { + unimplemented!("only for IVF") + } + + async fn search_in_partition( + &self, + _: usize, + _: &Query, + _: Arc, + ) -> Result { + unimplemented!("only for IVF") + } + fn is_loadable(&self) -> bool { true } @@ -230,7 +248,7 @@ impl VectorIndex for HNSWIndex { .await?; let mut schema = batch.schema_ref().as_ref().clone(); schema.metadata.insert( - HNSW_METADATA_KEY.to_string(), + HNSW::metadata_key().to_owned(), serde_json::to_string(&metadata)?, ); let batch = batch.with_schema(schema.into())?; @@ -256,6 +274,21 @@ impl VectorIndex for HNSWIndex { }) } + fn ivf_model(&self) -> IvfModel { + unimplemented!("only for IVF") + } + + fn quantizer(&self) -> Quantizer { + self.partition_storage.quantizer().clone() + } + + fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + ( + SubIndexType::Hnsw, + self.partition_storage.quantizer().quantization_type(), + ) + } + fn metric_type(&self) -> DistanceType { self.partition_storage.distance_type() } diff --git a/rust/lance-index/src/vector/ivf.rs b/rust/lance-index/src/vector/ivf.rs index 4e5b8425d8..87fcdf164d 100644 --- a/rust/lance-index/src/vector/ivf.rs +++ b/rust/lance-index/src/vector/ivf.rs @@ -15,14 +15,13 @@ use lance_linalg::{ kmeans::{compute_partitions_arrow_array, kmeans_find_partitions_arrow_array}, }; -use crate::vector::ivf::transform::IvfTransformer; +use crate::vector::ivf::transform::PartitionTransformer; use crate::vector::{ pq::{transform::PQTransformer, ProductQuantizer}, residual::ResidualTransform, transform::Transformer, }; -use super::transform::DropColumn; use super::{quantizer::Quantizer, residual::compute_residual}; use super::{PART_ID_COLUMN, PQ_CODE_COLUMN, RESIDUAL_COLUMN}; @@ -40,38 +39,48 
@@ mod transform; /// - *metric_type*: metric type to compute pair-wise vector distance. /// - *transforms*: a list of transforms to apply to the vector column. /// - *range*: only covers a range of partitions. Default is None -pub fn new_ivf( +pub fn new_ivf_transformer( centroids: FixedSizeListArray, metric_type: DistanceType, transforms: Vec>, -) -> Ivf { - Ivf::new(centroids, metric_type, transforms) +) -> IvfTransformer { + IvfTransformer::new(centroids, metric_type, transforms) } -pub fn new_ivf_with_quantizer( +pub fn new_ivf_transformer_with_quantizer( centroids: FixedSizeListArray, metric_type: MetricType, vector_column: &str, quantizer: Quantizer, range: Option>, -) -> Result { +) -> Result { match quantizer { - Quantizer::Flat(_) => Ok(Ivf::new_flat(centroids, metric_type, vector_column, range)), - Quantizer::Product(pq) => Ok(Ivf::with_pq( + Quantizer::Flat(_) => Ok(IvfTransformer::new_flat( + centroids, + metric_type, + vector_column, + range, + )), + Quantizer::Product(pq) => Ok(IvfTransformer::with_pq( centroids, metric_type, vector_column, pq, range, )), - Quantizer::Scalar(_) => Ok(Ivf::with_sq(centroids, metric_type, vector_column, range)), + Quantizer::Scalar(_) => Ok(IvfTransformer::with_sq( + centroids, + metric_type, + vector_column, + range, + )), } } /// IVF - IVF file partition /// #[derive(Debug)] -pub struct Ivf { +pub struct IvfTransformer { /// Centroids of a cluster algorithm, to run IVF. /// /// It is a 2-D `(num_partitions * dimension)` of floating array. @@ -84,7 +93,7 @@ pub struct Ivf { distance_type: DistanceType, } -impl Ivf { +impl IvfTransformer { /// Create a new Ivf model. 
pub fn new( centroids: FixedSizeListArray, @@ -115,7 +124,11 @@ impl Ivf { distance_type }; - let ivf_transform = Arc::new(IvfTransformer::new(centroids.clone(), dt, vector_column)); + let ivf_transform = Arc::new(PartitionTransformer::new( + centroids.clone(), + dt, + vector_column, + )); transforms.push(ivf_transform.clone()); if let Some(range) = range { @@ -151,8 +164,12 @@ impl Ivf { distance_type }; - let ivf_transform = Arc::new(IvfTransformer::new(centroids.clone(), mt, vector_column)); - transforms.push(ivf_transform.clone()); + let partition_transform = Arc::new(PartitionTransformer::new( + centroids.clone(), + mt, + vector_column, + )); + transforms.push(partition_transform.clone()); if let Some(range) = range { transforms.push(Arc::new(transform::PartitionFilter::new( @@ -203,8 +220,12 @@ impl Ivf { metric_type }; - let ivf_transform = Arc::new(IvfTransformer::new(centroids.clone(), mt, vector_column)); - transforms.push(ivf_transform.clone()); + let partition_transformer = Arc::new(PartitionTransformer::new( + centroids.clone(), + mt, + vector_column, + )); + transforms.push(partition_transformer.clone()); if let Some(range) = range { transforms.push(Arc::new(transform::PartitionFilter::new( @@ -213,9 +234,6 @@ impl Ivf { ))); } - // For SQ we will transofrm the vector to SQ code while building the index, - // so simply drop the vector column now. 
- transforms.push(Arc::new(DropColumn::new(vector_column))); Self { centroids, distance_type: metric_type, @@ -243,7 +261,7 @@ impl Ivf { } } -impl Transformer for Ivf { +impl Transformer for IvfTransformer { fn transform(&self, batch: &RecordBatch) -> Result { let mut batch = batch.clone(); for transform in self.transforms.as_slice() { diff --git a/rust/lance-index/src/vector/ivf/shuffler.rs b/rust/lance-index/src/vector/ivf/shuffler.rs index dd02ff9566..e76387b359 100644 --- a/rust/lance-index/src/vector/ivf/shuffler.rs +++ b/rust/lance-index/src/vector/ivf/shuffler.rs @@ -33,7 +33,7 @@ use object_store::path::Path; use snafu::{location, Location}; use tempfile::TempDir; -use crate::vector::ivf::Ivf; +use crate::vector::ivf::IvfTransformer; use crate::vector::transform::{KeepFiniteVectors, Transformer}; use crate::vector::PART_ID_COLUMN; @@ -70,7 +70,7 @@ fn get_temp_dir() -> Result { pub async fn shuffle_dataset( data: impl RecordBatchStream + Unpin + 'static, column: &str, - ivf: Arc, + ivf: Arc, precomputed_partitions: Option>, num_partitions: u32, shuffle_partition_batches: usize, diff --git a/rust/lance-index/src/vector/ivf/storage.rs b/rust/lance-index/src/vector/ivf/storage.rs index dc94351245..9d5a8fafad 100644 --- a/rust/lance-index/src/vector/ivf/storage.rs +++ b/rust/lance-index/src/vector/ivf/storage.rs @@ -2,13 +2,14 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::ops::Range; -use std::sync::Arc; -use arrow_array::{Array, FixedSizeListArray}; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt32Array}; use deepsize::DeepSizeOf; +use itertools::Itertools; use lance_core::{Error, Result}; use lance_file::{reader::FileReader, writer::FileWriter}; use lance_io::{traits::WriteExt, utils::read_message}; +use lance_linalg::distance::DistanceType; use lance_table::io::manifest::ManifestDescribing; use log::debug; use serde::{Deserialize, Serialize}; @@ -19,53 +20,110 @@ use crate::pb::Ivf as PbIvf; pub const IVF_METADATA_KEY: &str 
= "lance:ivf"; pub const IVF_PARTITION_KEY: &str = "lance:ivf:partition"; +/// Ivf Model #[derive(Debug, Clone, PartialEq)] -pub struct IvfData { - /// Centroids of the IVF indices. Can be empty. - centroids: Option>, +pub struct IvfModel { + /// Centroids of each partition. + /// + /// It is a 2-D `(num_partitions * dimension)` of vector array. + pub centroids: Option, - /// Length of each partition. - lengths: Vec, + /// Offset of each partition in the file. + pub offsets: Vec, - /// pre-computed row offset for each partition, do not persist. - partition_row_offsets: Vec, + /// Number of vectors in each partition. + pub lengths: Vec, } -impl DeepSizeOf for IvfData { +impl DeepSizeOf for IvfModel { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.centroids .as_ref() .map(|centroids| centroids.get_array_memory_size()) - .unwrap_or(0) + .unwrap_or_default() + self.lengths.deep_size_of_children(context) - + self.partition_row_offsets.deep_size_of_children(context) + + self.offsets.deep_size_of_children(context) } } -/// The IVF metadata stored in the Lance Schema -#[derive(Serialize, Deserialize, Debug)] -struct IvfMetadata { - // The file position to store the protobuf binary of IVF metadata. - pb_position: usize, -} - -impl IvfData { +impl IvfModel { pub fn empty() -> Self { Self { centroids: None, + offsets: vec![], lengths: vec![], - partition_row_offsets: vec![0], } } - pub fn with_centroids(centroids: Arc) -> Self { + pub fn new(centroids: FixedSizeListArray) -> Self { Self { centroids: Some(centroids), + offsets: vec![], lengths: vec![], - partition_row_offsets: vec![0], } } + pub fn centroid(&self, partition: usize) -> Option { + self.centroids.as_ref().map(|c| c.value(partition)) + } + + /// Ivf model dimension. + pub fn dimension(&self) -> usize { + self.centroids + .as_ref() + .map(|c| c.value_length() as usize) + .unwrap_or(0) + } + + /// Number of IVF partitions. 
+ pub fn num_partitions(&self) -> usize { + self.centroids + .as_ref() + .map(|c| c.len()) + .unwrap_or_else(|| self.offsets.len()) + } + + pub fn partition_size(&self, part: usize) -> usize { + self.lengths[part] as usize + } + + /// Use the query vector to find `nprobes` closest partitions. + pub fn find_partitions( + &self, + query: &dyn Array, + nprobes: usize, + distance_type: DistanceType, + ) -> Result { + let internal = crate::vector::ivf::new_ivf_transformer( + self.centroids.clone().unwrap(), + distance_type, + vec![], + ); + internal.find_partitions(query, nprobes) + } + + /// Add the offset and length of one partition. + pub fn add_partition(&mut self, len: u32) { + self.offsets.push( + self.offsets.last().cloned().unwrap_or_default() + + self.lengths.last().cloned().unwrap_or_default() as usize, + ); + self.lengths.push(len); + } + + /// Add the offset and length of one partition with the given offset. + /// this is used for old index format of IVF_PQ. + pub fn add_partition_with_offset(&mut self, offset: usize, len: u32) { + self.offsets.push(offset); + self.lengths.push(len); + } + + pub fn row_range(&self, partition: usize) -> Range { + let start = self.offsets[partition]; + let end = start + self.lengths[partition] as usize; + start..end + } + pub async fn load(reader: &FileReader) -> Result { let schema = reader.schema(); let meta_str = schema.metadata.get(IVF_METADATA_KEY).ok_or(Error::Index { @@ -94,37 +152,32 @@ impl IvfData { writer.add_metadata(IVF_METADATA_KEY, &serde_json::to_string(&ivf_metadata)?); Ok(()) } +} - pub fn add_partition(&mut self, num_rows: u32) { - self.lengths.push(num_rows); - let last_offset = self.partition_row_offsets.last().copied().unwrap_or(0); - self.partition_row_offsets - .push(last_offset + num_rows as usize); - } +/// Convert IvfModel to protobuf. 
+impl TryFrom<&IvfModel> for PbIvf { + type Error = Error; - pub fn has_centroids(&self) -> bool { - self.centroids.is_some() - } + fn try_from(ivf: &IvfModel) -> Result { + let lengths = ivf.lengths.clone(); - pub fn num_partitions(&self) -> usize { - self.lengths.len() - } - - /// Range of the rows for one partition. - pub fn row_range(&self, partition: usize) -> Range { - let start = self.partition_row_offsets[partition]; - let end = self.partition_row_offsets[partition + 1]; - start..end + Ok(Self { + centroids: vec![], // Deprecated + lengths, + offsets: ivf.offsets.iter().map(|x| *x as u64).collect(), + centroids_tensor: ivf.centroids.as_ref().map(|c| c.try_into()).transpose()?, + }) } } -impl TryFrom for IvfData { +/// Convert protobuf to IvfModel. +impl TryFrom for IvfModel { type Error = Error; fn try_from(proto: PbIvf) -> Result { let centroids = if let Some(tensor) = proto.centroids_tensor.as_ref() { debug!("Ivf: loading IVF centroids from index format v2"); - Some(Arc::new(FixedSizeListArray::try_from(tensor)?)) + Some(FixedSizeListArray::try_from(tensor)?) } else { None }; @@ -132,42 +185,38 @@ impl TryFrom for IvfData { // v1 index format. It will be deprecated soon. // // This new offset uses the row offset in the lance file.
- let offsets = [0] - .iter() - .chain(proto.lengths.iter()) - .scan(0_usize, |state, &x| { - *state += x as usize; - Some(*state) - }); + let offsets = match proto.offsets.len() { + 0 => proto + .lengths + .iter() + .scan(0_usize, |state, &x| { + let old = *state; + *state += x as usize; + Some(old) + }) + .collect_vec(), + _ => proto.offsets.iter().map(|x| *x as usize).collect(), + }; + assert_eq!(offsets.len(), proto.lengths.len()); Ok(Self { centroids, + offsets, lengths: proto.lengths.clone(), - partition_row_offsets: offsets.collect(), }) } } -impl TryFrom<&IvfData> for PbIvf { - type Error = Error; - - fn try_from(meta: &IvfData) -> Result { - let lengths = meta.lengths.clone(); - - Ok(Self { - centroids: vec![], // Deprecated - lengths, - offsets: vec![], // Deprecated - centroids_tensor: meta - .centroids - .as_ref() - .map(|c| c.as_ref().try_into()) - .transpose()?, - }) - } +/// The IVF metadata stored in the Lance Schema +#[derive(Serialize, Deserialize, Debug)] +struct IvfMetadata { + // The file position to store the protobuf binary of IVF metadata. 
+ pb_position: usize, } #[cfg(test)] mod tests { + use std::sync::Arc; + use arrow_array::{Float32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use lance_core::datatypes::Schema; @@ -179,7 +228,7 @@ mod tests { #[test] fn test_ivf_find_rows() { - let mut ivf = IvfData::empty(); + let mut ivf = IvfModel::empty(); ivf.add_partition(20); ivf.add_partition(50); @@ -189,7 +238,7 @@ mod tests { #[tokio::test] async fn test_write_and_load() { - let mut ivf = IvfData::empty(); + let mut ivf = IvfModel::empty(); ivf.add_partition(20); ivf.add_partition(50); @@ -219,7 +268,7 @@ mod tests { .unwrap(); assert!(reader.schema().metadata.contains_key(IVF_METADATA_KEY)); - let ivf2 = IvfData::load(&reader).await.unwrap(); + let ivf2 = IvfModel::load(&reader).await.unwrap(); assert_eq!(ivf, ivf2); assert_eq!(ivf2.num_partitions(), 2); } diff --git a/rust/lance-index/src/vector/ivf/transform.rs b/rust/lance-index/src/vector/ivf/transform.rs index 81bd5e376f..d05bb8d090 100644 --- a/rust/lance-index/src/vector/ivf/transform.rs +++ b/rust/lance-index/src/vector/ivf/transform.rs @@ -30,14 +30,14 @@ use super::PART_ID_COLUMN; /// this transform is a Noop. /// #[derive(Debug)] -pub struct IvfTransformer { +pub struct PartitionTransformer { centroids: FixedSizeListArray, distance_type: DistanceType, input_column: String, output_column: String, } -impl IvfTransformer { +impl PartitionTransformer { pub fn new( centroids: FixedSizeListArray, distance_type: DistanceType, @@ -60,7 +60,7 @@ impl IvfTransformer { .into() } } -impl Transformer for IvfTransformer { +impl Transformer for PartitionTransformer { fn transform(&self, batch: &RecordBatch) -> Result { if batch.column_by_name(&self.output_column).is_some() { // If the partition ID column is already present, we don't need to compute it again. 
diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index 37473aa0b7..9eee17c681 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -14,7 +14,7 @@ use lance_arrow::*; use lance_core::{Error, Result}; use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, DistanceType, Dot, L2}; use lance_linalg::kernels::argmin_value_float; -use lance_linalg::kmeans::kmeans_find_partitions; +use lance_linalg::kmeans::compute_partition; use lance_linalg::{distance::MetricType, MatrixView}; use rayon::prelude::*; use snafu::{location, Location}; @@ -360,14 +360,15 @@ where location: location!(), })?; + let sub_dim = dim / num_sub_vectors; let values = flatten_data .as_slice() .par_chunks(dim) .map(|vector| { vector - .chunks_exact(dim / num_sub_vectors) + .chunks_exact(sub_dim) .enumerate() - .flat_map(|(sub_idx, sub_vector)| { + .map(|(sub_idx, sub_vector)| { let centroids = get_sub_vector_centroids( codebook.as_slice(), dim, @@ -375,9 +376,7 @@ where num_sub_vectors, sub_idx, ); - let parts = kmeans_find_partitions(centroids, sub_vector, 1, distance_type) - .expect("kmeans_find_partitions failed"); - parts.values().iter().map(|v| *v as u8).collect::>() + compute_partition(centroids, sub_vector, distance_type).map(|v| v as u8) }) .collect::>() }) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 87b0fcd19e..4171fc40da 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -32,6 +32,7 @@ use serde::{Deserialize, Serialize}; use snafu::{location, Location}; use super::{distance::build_distance_table_l2, num_centroids, ProductQuantizerImpl}; +use crate::vector::storage::STORAGE_METADATA_KEY; use crate::{ pb, vector::{ @@ -398,7 +399,7 @@ impl VectorStore for ProductQuantizationStorage { let metadata_json = batch .schema_ref() .metadata() - .get("metadata") + .get(STORAGE_METADATA_KEY) 
.ok_or(Error::Index { message: "Metadata not found in schema".to_string(), location: location!(), @@ -456,7 +457,7 @@ impl VectorStore for ProductQuantizationStorage { }; let metadata_json = serde_json::to_string(&metadata)?; - let metadata = HashMap::from_iter(vec![("metadata".to_string(), metadata_json)]); + let metadata = HashMap::from_iter(vec![(STORAGE_METADATA_KEY.to_string(), metadata_json)]); let schema = self .batch diff --git a/rust/lance-index/src/vector/pq/utils.rs b/rust/lance-index/src/vector/pq/utils.rs index 08161947ce..80d5ff34d8 100644 --- a/rust/lance-index/src/vector/pq/utils.rs +++ b/rust/lance-index/src/vector/pq/utils.rs @@ -54,6 +54,7 @@ pub fn num_centroids(num_bits: impl Into) -> usize { 2_usize.pow(num_bits.into()) } +#[inline] pub fn get_sub_vector_centroids( codebook: &[T], dimension: usize, diff --git a/rust/lance-index/src/vector/quantizer.rs b/rust/lance-index/src/vector/quantizer.rs index 550ec463bb..83ea913d54 100644 --- a/rust/lance-index/src/vector/quantizer.rs +++ b/rust/lance-index/src/vector/quantizer.rs @@ -2,10 +2,13 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use core::fmt; +use std::fmt::Debug; use std::sync::Arc; -use arrow::datatypes::Float32Type; +use arrow::array::AsArray; +use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array}; +use arrow_schema::DataType; use async_trait::async_trait; use deepsize::DeepSizeOf; use lance_arrow::ArrowFloatType; @@ -22,9 +25,10 @@ use crate::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; use super::flat::index::FlatQuantizer; use super::pq::storage::PQ_METADTA_KEY; use super::pq::ProductQuantizer; +use super::sq::builder::SQBuildParams; use super::sq::storage::SQ_METADATA_KEY; use super::{ - ivf::storage::IvfData, + ivf::storage::IvfModel, pq::{ storage::{ProductQuantizationMetadata, ProductQuantizationStorage}, ProductQuantizerImpl, @@ -37,15 +41,21 @@ use super::{ }; use super::{PQ_CODE_COLUMN, 
SQ_CODE_COLUMN}; -pub trait Quantization: Send + Sync + DeepSizeOf + Into { +pub trait Quantization: Send + Sync + Debug + DeepSizeOf + Into { + type BuildParams: QuantizerBuildParams; type Metadata: QuantizerMetadata + Send + Sync; - type Storage: QuantizerStorage + VectorStore; + type Storage: QuantizerStorage + VectorStore + Debug; + fn build( + data: &dyn Array, + distance_type: DistanceType, + params: &Self::BuildParams, + ) -> Result; fn code_dim(&self) -> usize; fn column(&self) -> &'static str; fn quantize(&self, vectors: &dyn Array) -> Result; fn metadata_key() -> &'static str; - fn quantization_type(&self) -> QuantizationType; + fn quantization_type() -> QuantizationType; fn metadata(&self, _: Option) -> Result; fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result; } @@ -66,6 +76,16 @@ impl std::fmt::Display for QuantizationType { } } +pub trait QuantizerBuildParams { + fn sample_size(&self) -> usize; +} + +impl QuantizerBuildParams for () { + fn sample_size(&self) -> usize { + 0 + } +} + /// Quantization Method. /// ///
@@ -105,9 +125,9 @@ impl Quantizer { pub fn quantization_type(&self) -> QuantizationType { match self { - Self::Flat(fq) => fq.quantization_type(), - Self::Product(pq) => pq.quantization_type(), - Self::Scalar(sq) => sq.quantization_type(), + Self::Flat(_) => QuantizationType::Flat, + Self::Product(_) => QuantizationType::Product, + Self::Scalar(_) => QuantizationType::Scalar, } } @@ -168,9 +188,42 @@ pub trait QuantizerStorage: Clone + Sized + DeepSizeOf + VectorStore { } impl Quantization for ScalarQuantizer { + type BuildParams = SQBuildParams; type Metadata = ScalarQuantizationMetadata; type Storage = ScalarQuantizationStorage; + fn build(data: &dyn Array, _: DistanceType, params: &Self::BuildParams) -> Result { + let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index { + message: format!( + "SQ builder: input is not a FixedSizeList: {}", + data.data_type() + ), + location: location!(), + })?; + + let mut quantizer = Self::new(params.num_bits, fsl.value_length() as usize); + + match fsl.value_type() { + DataType::Float16 => { + quantizer.update_bounds::(fsl)?; + } + DataType::Float32 => { + quantizer.update_bounds::(fsl)?; + } + DataType::Float64 => { + quantizer.update_bounds::(fsl)?; + } + _ => { + return Err(Error::Index { + message: format!("SQ builder: unsupported data type: {}", fsl.value_type()), + location: location!(), + }) + } + } + + Ok(quantizer) + } + fn code_dim(&self) -> usize { self.dim } @@ -180,15 +233,22 @@ impl Quantization for ScalarQuantizer { } fn quantize(&self, vectors: &dyn Array) -> Result { - let code_array = self.transform::(vectors)?; - Ok(code_array) + match vectors.as_fixed_size_list().value_type() { + DataType::Float16 => self.transform::(vectors), + DataType::Float32 => self.transform::(vectors), + DataType::Float64 => self.transform::(vectors), + value_type => Err(Error::invalid_input( + format!("unsupported data type {} for scalar quantizer", value_type), + location!(), + )), + } } fn metadata_key() -> &'static str { 
SQ_METADATA_KEY } - fn quantization_type(&self) -> QuantizationType { + fn quantization_type() -> QuantizationType { QuantizationType::Scalar } @@ -210,9 +270,14 @@ impl Quantization for ScalarQuantizer { } impl Quantization for Arc { + type BuildParams = (); type Metadata = ProductQuantizationMetadata; type Storage = ProductQuantizationStorage; + fn build(_: &dyn Array, _: DistanceType, _: &Self::BuildParams) -> Result { + unimplemented!("ProductQuantizer cannot be built with new index builder") + } + fn code_dim(&self) -> usize { self.num_sub_vectors() } @@ -230,7 +295,7 @@ impl Quantization for Arc { PQ_METADTA_KEY } - fn quantization_type(&self) -> QuantizationType { + fn quantization_type() -> QuantizationType { QuantizationType::Product } @@ -278,9 +343,14 @@ impl Quantization for ProductQuantizerImpl where T::Native: Dot + L2, { + type BuildParams = (); type Metadata = ProductQuantizationMetadata; type Storage = ProductQuantizationStorage; + fn build(_: &dyn Array, _: DistanceType, _: &Self::BuildParams) -> Result { + unimplemented!("ProductQuantizer cannot be built with new index builder") + } + fn code_dim(&self) -> usize { self.num_sub_vectors() } @@ -298,7 +368,7 @@ where PQ_METADTA_KEY } - fn quantization_type(&self) -> QuantizationType { + fn quantization_type() -> QuantizationType { QuantizationType::Product } @@ -348,7 +418,7 @@ pub struct IvfQuantizationStorage { quantizer: Quantizer, metadata: Q::Metadata, - ivf: IvfData, + ivf: IvfModel, } impl DeepSizeOf for IvfQuantizationStorage { @@ -398,7 +468,7 @@ impl IvfQuantizationStorage { })?; let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?; - let ivf_data = IvfData::load(&reader).await?; + let ivf_data = IvfModel::load(&reader).await?; let metadata = Q::Metadata::load(&reader).await?; let quantizer = Q::from_metadata(&metadata, distance_type)?; diff --git a/rust/lance-index/src/vector/sq.rs b/rust/lance-index/src/vector/sq.rs index a3c7f5238c..1ef31a9dec 100644 --- 
a/rust/lance-index/src/vector/sq.rs +++ b/rust/lance-index/src/vector/sq.rs @@ -13,6 +13,8 @@ use lance_core::{Error, Result}; use num_traits::*; use snafu::{location, Location}; +use super::quantizer::Quantizer; + pub mod builder; pub mod storage; pub mod transform; @@ -129,6 +131,19 @@ impl ScalarQuantizer { } } +impl TryFrom for ScalarQuantizer { + type Error = Error; + fn try_from(value: Quantizer) -> Result { + match value { + Quantizer::Scalar(sq) => Ok(sq), + _ => Err(Error::Index { + message: "Expect to be a ScalarQuantizer".to_string(), + location: location!(), + }), + } + } +} + pub(crate) fn scale_to_u8(values: &[T::Native], bounds: Range) -> Vec { let range = bounds.end - bounds.start; values diff --git a/rust/lance-index/src/vector/sq/builder.rs b/rust/lance-index/src/vector/sq/builder.rs index fb885e2cd2..913751062c 100644 --- a/rust/lance-index/src/vector/sq/builder.rs +++ b/rust/lance-index/src/vector/sq/builder.rs @@ -1,18 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use arrow::{ - array::AsArray, - datatypes::{Float16Type, Float32Type, Float64Type}, -}; -use arrow_array::Array; - -use arrow_schema::DataType; -use lance_core::{Error, Result}; -use lance_linalg::distance::DistanceType; -use snafu::{location, Location}; - -use super::ScalarQuantizer; +use crate::vector::quantizer::QuantizerBuildParams; #[derive(Debug, Clone)] pub struct SQBuildParams { @@ -32,36 +21,8 @@ impl Default for SQBuildParams { } } -impl SQBuildParams { - pub fn build(&self, data: &dyn Array, _: DistanceType) -> Result { - let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index { - message: format!( - "SQ builder: input is not a FixedSizeList: {}", - data.data_type() - ), - location: location!(), - })?; - - let mut quantizer = ScalarQuantizer::new(self.num_bits, fsl.value_length() as usize); - - match fsl.value_type() { - DataType::Float16 => { - quantizer.update_bounds::(fsl)?; - } - DataType::Float32 => { - 
quantizer.update_bounds::(fsl)?; - } - DataType::Float64 => { - quantizer.update_bounds::(fsl)?; - } - _ => { - return Err(Error::Index { - message: format!("SQ builder: unsupported data type: {}", fsl.value_type()), - location: location!(), - }) - } - } - - Ok(quantizer) +impl QuantizerBuildParams for SQBuildParams { + fn sample_size(&self) -> usize { + self.sample_rate * 2usize.pow(self.num_bits as u32) } } diff --git a/rust/lance-index/src/vector/sq/storage.rs b/rust/lance-index/src/vector/sq/storage.rs index 52fa3ff84f..beb83f5851 100644 --- a/rust/lance-index/src/vector/sq/storage.rs +++ b/rust/lance-index/src/vector/sq/storage.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::{collections::HashMap, ops::Range, sync::Arc}; +use std::ops::Range; use arrow::compute::concat_batches; use arrow_array::{ @@ -21,6 +21,7 @@ use object_store::path::Path; use serde::{Deserialize, Serialize}; use snafu::{location, Location}; +use crate::vector::storage::STORAGE_METADATA_KEY; use crate::{ vector::{ quantizer::{QuantizerMetadata, QuantizerStorage}, @@ -70,7 +71,7 @@ impl QuantizerMetadata for ScalarQuantizationMetadata { } /// An immutable chunk of SclarQuantizationStorage. 
-#[derive(Clone)] +#[derive(Debug, Clone)] struct SQStorageChunk { batch: RecordBatch, @@ -150,7 +151,7 @@ impl DeepSizeOf for SQStorageChunk { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct ScalarQuantizationStorage { quantizer: ScalarQuantizer, @@ -291,7 +292,7 @@ impl VectorStore for ScalarQuantizationStorage { let metadata_json = batch .schema_ref() .metadata() - .get("metadata") + .get(STORAGE_METADATA_KEY) .ok_or(Error::Schema { message: "metadata not found".to_string(), location: location!(), @@ -302,27 +303,7 @@ impl VectorStore for ScalarQuantizationStorage { } fn to_batches(&self) -> Result> { - let metadata = ScalarQuantizationMetadata { - dim: self.chunks[0].dim(), - num_bits: self.quantizer.num_bits, - bounds: self.quantizer.bounds.clone(), - }; - let metadata_json = serde_json::to_string(&metadata)?; - let metadata = HashMap::from_iter(vec![("metadata".to_owned(), metadata_json)]); - - let schema = self.chunks[0] - .schema() - .as_ref() - .clone() - .with_metadata(metadata); - let schema = Arc::new(schema); - Ok(self.chunks.iter().map(move |chunk| { - chunk - .batch - .clone() - .with_schema(schema.clone()) - .expect("attach schema") - })) + Ok(self.chunks.iter().map(|c| c.batch.clone())) } fn append_batch(&self, batch: RecordBatch, vector_column: &str) -> Result { @@ -458,6 +439,7 @@ mod tests { use super::*; use std::iter::repeat_with; + use std::sync::Arc; use arrow_array::FixedSizeListArray; use arrow_schema::{DataType, Field, Schema}; diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index ae5834de8a..d11a132801 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -22,12 +22,13 @@ use snafu::{location, Location}; use crate::{ pb, vector::{ - ivf::storage::{IvfData, IVF_METADATA_KEY}, - quantizer::{Quantization, Quantizer}, + ivf::storage::{IvfModel, IVF_METADATA_KEY}, + quantizer::Quantization, }, INDEX_METADATA_SCHEMA_KEY, }; +use 
super::quantizer::Quantizer; use super::DISTANCE_TYPE_KEY; ///
@@ -40,6 +41,8 @@ pub trait DistCalculator { fn prefetch(&self, _id: u32) {} } +pub const STORAGE_METADATA_KEY: &str = "storage_metadata"; + /// Vector Storage is the abstraction to store the vectors. /// /// It can be in-memory or on-disk, raw vector or quantized vectors. @@ -121,13 +124,19 @@ impl StorageBuilder { location: location!(), })?; let code_array = self.quantizer.quantize(vectors.as_ref())?; - let batch = batch.drop_column(&self.column)?.try_with_column( - Field::new( - self.quantizer.column(), - code_array.data_type().clone(), - true, - ), - code_array, + let batch = batch + .try_with_column( + Field::new( + self.quantizer.column(), + code_array.data_type().clone(), + true, + ), + code_array, + )? + .drop_column(&self.column)?; + let batch = batch.add_metadata( + STORAGE_METADATA_KEY.to_owned(), + self.quantizer.metadata(None)?.to_string(), )?; Q::Storage::try_from_batch(batch, self.distance_type) } @@ -135,30 +144,27 @@ impl StorageBuilder { /// Loader to load partitioned PQ storage from disk. #[derive(Debug)] -pub struct IvfQuantizationStorage { +pub struct IvfQuantizationStorage { reader: FileReader, distance_type: DistanceType, - quantizer: Quantizer, - metadata: Q::Metadata, + metadata: Vec, - ivf: IvfData, + ivf: IvfModel, } -impl DeepSizeOf for IvfQuantizationStorage { +impl DeepSizeOf for IvfQuantizationStorage { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { - self.quantizer.deep_size_of_children(context) - + self.metadata.deep_size_of_children(context) - + self.ivf.deep_size_of_children(context) + self.metadata.deep_size_of_children(context) + self.ivf.deep_size_of_children(context) } } #[allow(dead_code)] -impl IvfQuantizationStorage { +impl IvfQuantizationStorage { /// Open a Loader. 
/// /// - pub async fn open(reader: FileReader) -> Result { + pub async fn try_new(reader: FileReader) -> Result { let schema = reader.schema(); let distance_type = DistanceType::try_from( @@ -185,34 +191,29 @@ impl IvfQuantizationStorage { location: location!(), })?; let ivf_bytes = reader.read_global_buffer(ivf_pos).await?; - let ivf = IvfData::try_from(pb::Ivf::decode(ivf_bytes)?)?; + let ivf = IvfModel::try_from(pb::Ivf::decode(ivf_bytes)?)?; - let quantizer_metadata: Q::Metadata = serde_json::from_str( + let metadata: Vec = serde_json::from_str( schema .metadata - .get(Q::metadata_key()) + .get(STORAGE_METADATA_KEY) .ok_or(Error::Index { - message: format!("{} not found", Q::metadata_key()), + message: format!("{} not found", STORAGE_METADATA_KEY), location: location!(), })? .as_str(), )?; - let quantizer = Q::from_metadata(&quantizer_metadata, distance_type)?; Ok(Self { reader, distance_type, - quantizer, - metadata: quantizer_metadata, + metadata, ivf, }) } - pub fn quantizer(&self) -> &Quantizer { - &self.quantizer - } - - pub fn metadata(&self) -> &Q::Metadata { - &self.metadata + pub fn quantizer(&self) -> Result { + let metadata = serde_json::from_str(&self.metadata[0])?; + Q::from_metadata(&metadata, self.distance_type) } /// Get the number of partitions in the storage. 
@@ -220,7 +221,7 @@ impl IvfQuantizationStorage { self.ivf.num_partitions() } - pub async fn load_partition(&self, part_id: usize) -> Result { + pub async fn load_partition(&self, part_id: usize) -> Result { let range = self.ivf.row_range(part_id); let batches = self .reader @@ -234,6 +235,10 @@ impl IvfQuantizationStorage { .await?; let schema = Arc::new(self.reader.schema().as_ref().into()); let batch = concat_batches(&schema, batches.iter())?; + let batch = batch.add_metadata( + STORAGE_METADATA_KEY.to_owned(), + self.metadata[part_id].clone(), + )?; Q::Storage::try_from_batch(batch, self.distance_type) } } diff --git a/rust/lance-index/src/vector/v3/shuffler.rs b/rust/lance-index/src/vector/v3/shuffler.rs index 6377526911..16d1911d37 100644 --- a/rust/lance-index/src/vector/v3/shuffler.rs +++ b/rust/lance-index/src/vector/v3/shuffler.rs @@ -35,7 +35,7 @@ pub trait ShuffleReader: Send + Sync { ) -> Result>>; /// Get the size of the partition by partition_id - fn partiton_size(&self, partition_id: usize) -> Result; + fn partition_size(&self, partition_id: usize) -> Result; } #[async_trait::async_trait] @@ -51,7 +51,7 @@ pub trait Shuffler: Send + Sync { } pub struct IvfShuffler { - object_store: ObjectStore, + object_store: Arc, output_dir: Path, num_partitions: usize, @@ -60,9 +60,9 @@ pub struct IvfShuffler { } impl IvfShuffler { - pub fn new(object_store: ObjectStore, output_dir: Path, num_partitions: usize) -> Self { + pub fn new(output_dir: Path, num_partitions: usize) -> Self { Self { - object_store, + object_store: Arc::new(ObjectStore::local()), output_dir, num_partitions, buffer_size: 4096, @@ -85,8 +85,6 @@ impl Shuffler for IvfShuffler { let mut partition_sizes = vec![0; self.num_partitions]; let mut first_pass = true; - let mut counter = 0; - let num_partitions = self.num_partitions; let mut parallel_sort_stream = data .map(|batch| { @@ -133,8 +131,8 @@ impl Shuffler for IvfShuffler { .map(|_| Vec::new()) .collect::>(); + let mut counter = 0; while 
let Some(shuffled) = parallel_sort_stream.next().await { - log::info!("shuffle batch: {}", counter); let shuffled = shuffled?; for (part_id, batches) in shuffled.into_iter().enumerate() { @@ -176,6 +174,7 @@ impl Shuffler for IvfShuffler { // do flush if counter % self.buffer_size == 0 { + log::info!("shuffle {} batches, flushing", counter); let mut futs = vec![]; for (part_id, writer) in writers.iter_mut().enumerate() { let batches = &partition_buffers[part_id]; @@ -208,27 +207,41 @@ impl Shuffler for IvfShuffler { writer.finish().await?; } - Ok(Box::new(IvfShufflerReader { - object_store: self.object_store.clone(), - output_dir: self.output_dir.clone(), + Ok(Box::new(IvfShufflerReader::new( + self.object_store.clone(), + self.output_dir.clone(), partition_sizes, - })) + ))) } } pub struct IvfShufflerReader { - object_store: ObjectStore, + object_store: Arc, output_dir: Path, partition_sizes: Vec, } +impl IvfShufflerReader { + pub fn new( + object_store: Arc, + output_dir: Path, + partition_sizes: Vec, + ) -> Self { + Self { + object_store, + output_dir, + partition_sizes, + } + } +} + #[async_trait::async_trait] impl ShuffleReader for IvfShufflerReader { async fn read_partition( &self, partition_id: usize, ) -> Result>> { - let scheduler = ScanScheduler::new(Arc::new(self.object_store.clone()), 32); + let scheduler = ScanScheduler::new(self.object_store.clone(), 32); let partition_path = self.output_dir.child(format!("ivf_{}.lance", partition_id)); let reader = FileReader::try_open( @@ -250,7 +263,7 @@ impl ShuffleReader for IvfShufflerReader { )))) } - fn partiton_size(&self, partition_id: usize) -> Result { + fn partition_size(&self, partition_id: usize) -> Result { Ok(self.partition_sizes[partition_id]) } } diff --git a/rust/lance-index/src/vector/v3/subindex.rs b/rust/lance-index/src/vector/v3/subindex.rs index 933af8f7f4..cd795e9599 100644 --- a/rust/lance-index/src/vector/v3/subindex.rs +++ b/rust/lance-index/src/vector/v3/subindex.rs @@ -1,19 +1,19 @@ // 
SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::fmt::Debug; use std::sync::Arc; use arrow_array::{ArrayRef, RecordBatch}; use deepsize::DeepSizeOf; -use lance_core::Result; +use lance_core::{Error, Result}; +use snafu::{location, Location}; use crate::vector::storage::VectorStore; +use crate::vector::{flat, hnsw}; use crate::{prefilter::PreFilter, vector::Query}; - -pub const SUB_INDEX_METADATA_KEY: &str = "sub_index_metadata"; - /// A sub index for IVF index -pub trait IvfSubIndex: Send + Sync + DeepSizeOf { +pub trait IvfSubIndex: Send + Sync + Debug + DeepSizeOf { type QueryParams: Send + Sync + for<'a> From<&'a Query>; type BuildParams: Clone; @@ -26,6 +26,8 @@ pub trait IvfSubIndex: Send + Sync + DeepSizeOf { fn name() -> &'static str; + fn metadata_key() -> &'static str; + /// Return the schema of the sub index fn schema() -> arrow_schema::SchemaRef; @@ -52,3 +54,32 @@ pub trait IvfSubIndex: Send + Sync + DeepSizeOf { /// Encode the sub index into a record batch fn to_batch(&self) -> Result; } + +pub enum SubIndexType { + Flat, + Hnsw, +} + +impl std::fmt::Display for SubIndexType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Flat => write!(f, "{}", flat::index::FlatIndex::name()), + Self::Hnsw => write!(f, "{}", hnsw::builder::HNSW::name()), + } + } +} + +impl TryFrom<&str> for SubIndexType { + type Error = Error; + + fn try_from(value: &str) -> Result { + match value { + "FLAT" => Ok(Self::Flat), + "HNSW" => Ok(Self::Hnsw), + _ => Err(Error::Index { + message: format!("unknown sub index type {}", value), + location: location!(), + }), + } + } +} diff --git a/rust/lance-linalg/src/distance/dot.rs b/rust/lance-linalg/src/distance/dot.rs index c84d43426d..226bfdae46 100644 --- a/rust/lance-linalg/src/distance/dot.rs +++ b/rust/lance-linalg/src/distance/dot.rs @@ -66,10 +66,10 @@ pub fn dot(from: &[T], to: &[T]) -> f32 { T::dot(from, to) } -/// Negative dot 
distance. +/// Negative [Dot] distance. #[inline] pub fn dot_distance(from: &[T], to: &[T]) -> f32 { - -T::dot(from, to) + 1.0 - T::dot(from, to) } /// Dot product diff --git a/rust/lance-linalg/src/kernels.rs b/rust/lance-linalg/src/kernels.rs index 1e9d3c20e1..a9aa41f89a 100644 --- a/rust/lance-linalg/src/kernels.rs +++ b/rust/lance-linalg/src/kernels.rs @@ -75,6 +75,7 @@ pub fn argmin_value( /// Returns the minimal value (float) and the index (argmin) from an Iterator. /// /// Return `None` if the iterator is empty or all are `Nan/Inf`. +#[inline] pub fn argmin_value_float(iter: impl Iterator) -> Option<(u32, T)> { let mut min_idx = None; let mut min_value = T::infinity(); diff --git a/rust/lance-linalg/src/kmeans.rs b/rust/lance-linalg/src/kmeans.rs index bb0e353638..80f3735b04 100644 --- a/rust/lance-linalg/src/kmeans.rs +++ b/rust/lance-linalg/src/kmeans.rs @@ -30,7 +30,7 @@ use rayon::prelude::*; use crate::distance::hamming::hamming; use crate::distance::{dot_distance_batch, DistanceType}; -use crate::kernels::argmax; +use crate::kernels::{argmax, argmin_value_float}; use crate::{ distance::{ l2::{l2_distance_batch, L2}, @@ -214,11 +214,11 @@ where let cluster_and_dists = match distance_type { DistanceType::L2 => data .par_chunks(dimension) - .map(|vec| argmin_value(l2_distance_batch(vec, centroids, dimension))) + .map(|vec| argmin_value_float(l2_distance_batch(vec, centroids, dimension))) .collect::>(), DistanceType::Dot => data .par_chunks(dimension) - .map(|vec| argmin_value(dot_distance_batch(vec, centroids, dimension))) + .map(|vec| argmin_value_float(dot_distance_batch(vec, centroids, dimension))) .collect::>(), _ => { panic!( @@ -500,8 +500,8 @@ impl KMeans { last_membership = Some(membership); if (loss - last_loss).abs() / last_loss < params.tolerance { info!( - "KMeans training: converged at iteration {} / {}, redo={}", - i, params.max_iters, redo + "KMeans training: converged at iteration {} / {}, redo={}, loss={}, last_loss={}, loss_diff={}", + 
i, params.max_iters, redo, loss, last_loss, (loss - last_loss).abs() / last_loss ); break; } @@ -685,22 +685,32 @@ pub fn compute_partitions( let dimension = dimension.as_(); vectors .par_chunks(dimension) - .map(|vec| { - argmin_value(match distance_type { - DistanceType::L2 => l2_distance_batch(vec, centroids, dimension), - DistanceType::Dot => dot_distance_batch(vec, centroids, dimension), - _ => { - panic!( - "KMeans::find_partitions: {} is not supported", - distance_type - ); - } - }) - .map(|(idx, _)| idx) - }) + .map(|vec| compute_partition(centroids, vec, distance_type)) .collect::>() } +#[inline] +pub fn compute_partition( + centroids: &[T], + vector: &[T], + distance_type: DistanceType, +) -> Option { + match distance_type { + DistanceType::L2 => { + argmin_value_float(l2_distance_batch(vector, centroids, vector.len())).map(|c| c.0) + } + DistanceType::Dot => { + argmin_value_float(dot_distance_batch(vector, centroids, vector.len())).map(|c| c.0) + } + _ => { + panic!( + "KMeans::compute_partition: distance type {} is not supported", + distance_type + ); + } + } +} + #[cfg(test)] mod tests { use std::iter::repeat; diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index a3930a42e9..7bb2192b97 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -442,7 +442,7 @@ impl FileFragment { with_row_id: bool, with_row_address: bool, ) -> Result { - let open_files = self.open_readers(projection, with_row_id, with_row_address); + let open_files = self.open_readers(projection); let deletion_vec_load = self.load_deletion_vector(&self.dataset.object_store, &self.metadata); @@ -460,10 +460,10 @@ impl FileFragment { let deletion_vec = deletion_vec?; let row_id_sequence = row_id_sequence?; - if opened_files.is_empty() { + if opened_files.is_empty() && !with_row_id && !with_row_address { return Err(Error::io( format!( - "Does not find any data file for schema: {}\nfragment_id={}", + "Did not find any 
data files for schema: {}\nfragment_id={}", projection, self.id() ), @@ -471,6 +471,8 @@ impl FileFragment { )); } + let num_physical_rows = self.physical_rows().await?; + let mut reader = FragmentReader::try_new( self.id(), deletion_vec, @@ -478,6 +480,7 @@ impl FileFragment { opened_files, ArrowSchema::from(projection), self.count_rows().await?, + num_physical_rows, )?; if with_row_id { @@ -498,8 +501,6 @@ impl FileFragment { &self, data_file: &DataFile, projection: Option<&Schema>, - with_row_id: bool, - with_row_address: bool, ) -> Result, Arc)>> { let full_schema = self.dataset.schema(); // The data file may contain fields that are not part of the dataset any longer, remove those @@ -510,7 +511,7 @@ impl FileFragment { if data_file.is_legacy_file() { let max_field_id = data_file.fields.iter().max().unwrap(); - if with_row_id || with_row_address || !schema_per_file.fields.is_empty() { + if !schema_per_file.fields.is_empty() { let path = self.dataset.data_dir().child(data_file.path.as_str()); let field_id_offset = Self::get_field_id_offset(data_file); let reader = FileReader::try_new_with_fragment_id( @@ -566,17 +567,10 @@ impl FileFragment { async fn open_readers( &self, projection: &Schema, - with_row_id: bool, - with_row_address: bool, ) -> Result, Arc)>> { let mut opened_files = vec![]; - for (i, data_file) in self.metadata.files.iter().enumerate() { - // TODO: do we still need to do this? - let with_row_id = (with_row_id || with_row_address) && i == 0; - if let Some((reader, schema)) = self - .open_reader(data_file, Some(projection), with_row_id, with_row_address) - .await? - { + for data_file in &self.metadata.files { + if let Some((reader, schema)) = self.open_reader(data_file, Some(projection)).await? { opened_files.push((reader, schema)); } } @@ -668,16 +662,16 @@ impl FileFragment { // Just open any file. All of them should have same size. 
let some_file = &self.metadata.files[0]; - let (reader, _) = self - .open_reader(some_file, None, false, false) - .await? - .ok_or_else(|| Error::Internal { - message: format!( - "The data file {} did not have any fields contained in the dataset schema", - some_file.path - ), - location: location!(), - })?; + let (reader, _) = + self.open_reader(some_file, None) + .await? + .ok_or_else(|| Error::Internal { + message: format!( + "The data file {} did not have any fields contained in the dataset schema", + some_file.path + ), + location: location!(), + })?; Ok(reader.len() as usize) } @@ -762,16 +756,13 @@ impl FileFragment { } let get_lengths = self.metadata.files.iter().map(|data_file| async move { - let (reader, _) = self - .open_reader(data_file, None, false, false) - .await? - .ok_or_else(|| { - Error::corrupt_file( - self.dataset.data_dir().child(data_file.path.clone()), - "did not have any fields in common with the dataset schema", - location!(), - ) - })?; + let (reader, _) = self.open_reader(data_file, None).await?.ok_or_else(|| { + Error::corrupt_file( + self.dataset.data_dir().child(data_file.path.clone()), + "did not have any fields in common with the dataset schema", + location!(), + ) + })?; Result::Ok(reader.len() as usize) }); let get_lengths = try_join_all(get_lengths); @@ -1207,8 +1198,11 @@ pub struct FragmentReader { /// If false, deleted rows will be removed from the batch, requiring a copy make_deletions_null: bool, - // total number of rows in the fragment + // total number of real rows in the fragment (num_physical_rows - num_deleted_rows) num_rows: usize, + + // total number of physical rows in the fragment (all rows, ignoring deletions) + num_physical_rows: usize, } // Custom clone impl needed because it is not easy to clone Box @@ -1231,6 +1225,7 @@ impl Clone for FragmentReader { with_row_addr: self.with_row_addr, make_deletions_null: self.make_deletions_null, num_rows: self.num_rows, + num_physical_rows: self.num_physical_rows, } } } @@ 
-1264,15 +1259,9 @@ impl FragmentReader { readers: Vec<(Box, Arc)>, output_schema: ArrowSchema, num_rows: usize, + num_physical_rows: usize, ) -> Result { - if readers.is_empty() { - return Err(Error::io( - "Cannot create FragmentReader with zero readers".to_string(), - location!(), - )); - } - - if let Some(legacy_reader) = readers[0].0.as_legacy_opt() { + if let Some(legacy_reader) = readers.first().and_then(|reader| reader.0.as_legacy_opt()) { let num_batches = legacy_reader.num_batches(); for reader in readers.iter().skip(1) { if let Some(other_legacy) = reader.0.as_legacy_opt() { @@ -1301,6 +1290,7 @@ impl FragmentReader { with_row_addr: false, make_deletions_null: false, num_rows, + num_physical_rows, }) } @@ -1350,7 +1340,7 @@ impl FragmentReader { /// use streams, the updater still needs to know the batch size in v1 so that it can create /// files with the same batch size. pub(crate) fn legacy_num_rows_in_batch(&self, batch_id: u32) -> Option { - if let Some(legacy_reader) = self.readers[0].0.as_legacy_opt() { + if let Some(legacy_reader) = self.readers.first().and_then(|r| r.0.as_legacy_opt()) { if batch_id < legacy_reader.num_batches() as u32 { Some(legacy_reader.num_rows_in_batch(batch_id as i32) as u32) } else { @@ -1513,7 +1503,7 @@ impl FragmentReader { batch_size: u32, read_fn: impl Fn(&dyn GenericFileReader, &Arc) -> Result, ) -> Result { - let total_num_rows = self.readers[0].0.len(); + let total_num_rows = self.num_physical_rows as u32; // Note that the fragment length might be considerably smaller if there are deleted rows. // E.g. 
if a fragment has 100 rows but rows 0..10 are deleted we still need to make // sure it is valid to read / take 0..100 diff --git a/rust/lance/src/dataset/optimize.rs b/rust/lance/src/dataset/optimize.rs index ed09a7999b..c32c4fa906 100644 --- a/rust/lance/src/dataset/optimize.rs +++ b/rust/lance/src/dataset/optimize.rs @@ -916,6 +916,7 @@ mod tests { use arrow_array::{Float32Array, Int64Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use arrow_select::concat::concat_batches; + use rstest::rstest; use tempfile::tempdir; use super::*; @@ -1113,8 +1114,9 @@ mod tests { .unwrap() } + #[rstest] #[tokio::test] - async fn test_compact_empty() { + async fn test_compact_empty(#[values(false, true)] use_legacy_format: bool) { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); @@ -1122,7 +1124,16 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); let reader = RecordBatchIterator::new(vec![].into_iter().map(Ok), Arc::new(schema)); - let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap(); + let mut dataset = Dataset::write( + reader, + test_uri, + Some(WriteParams { + use_legacy_format, + ..Default::default() + }), + ) + .await + .unwrap(); let plan = plan_compaction(&dataset, &CompactionOptions::default()) .await @@ -1137,8 +1148,9 @@ mod tests { assert_eq!(dataset.manifest.version, 1); } + #[rstest] #[tokio::test] - async fn test_compact_all_good() { + async fn test_compact_all_good(#[values(false, true)] use_legacy_format: bool) { // Compact a table with nothing to do let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); @@ -1148,6 +1160,7 @@ mod tests { // Just one file let write_params = WriteParams { max_rows_per_file: 10_000, + use_legacy_format, ..Default::default() }; let dataset = Dataset::write(reader, test_uri, Some(write_params)) @@ -1165,6 +1178,7 @@ mod tests { let write_params = WriteParams { 
max_rows_per_file: 3_000, max_rows_per_group: 1_000, + use_legacy_format, mode: WriteMode::Overwrite, ..Default::default() }; @@ -1217,8 +1231,9 @@ mod tests { } } + #[rstest] #[tokio::test] - async fn test_compact_many() { + async fn test_compact_many(#[values(false, true)] use_legacy_format: bool) { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); @@ -1228,6 +1243,7 @@ mod tests { let reader = RecordBatchIterator::new(vec![Ok(data.slice(0, 1200))], data.schema()); let write_params = WriteParams { max_rows_per_file: 400, + use_legacy_format, ..Default::default() }; Dataset::write(reader, test_uri, Some(write_params)) @@ -1238,6 +1254,7 @@ mod tests { let reader = RecordBatchIterator::new(vec![Ok(data.slice(1200, 2000))], data.schema()); let write_params = WriteParams { max_rows_per_file: 1000, + use_legacy_format, mode: WriteMode::Append, ..Default::default() }; @@ -1255,6 +1272,7 @@ mod tests { let reader = RecordBatchIterator::new(vec![Ok(data.slice(3200, 600))], data.schema()); let write_params = WriteParams { max_rows_per_file: 300, + use_legacy_format, mode: WriteMode::Append, ..Default::default() }; @@ -1358,8 +1376,9 @@ mod tests { assert_eq!(fragment_ids, vec![3, 7, 8, 9, 10]); } + #[rstest] #[tokio::test] - async fn test_compact_data_files() { + async fn test_compact_data_files(#[values(false, true)] use_legacy_format: bool) { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); @@ -1370,6 +1389,7 @@ mod tests { let write_params = WriteParams { max_rows_per_file: 5_000, max_rows_per_group: 1_000, + use_legacy_format, ..Default::default() }; let mut dataset = Dataset::write(reader, test_uri, Some(write_params)) @@ -1439,8 +1459,9 @@ mod tests { assert_eq!(scanned_data, data); } + #[rstest] #[tokio::test] - async fn test_compact_deletions() { + async fn test_compact_deletions(#[values(false, true)] use_legacy_format: bool) { // For files that have few rows, we don't want to compact just 1 
since // that won't do anything. But if there are deletions to materialize, // we want to do groups of 1. This test checks that. @@ -1453,6 +1474,7 @@ mod tests { let reader = RecordBatchIterator::new(vec![Ok(data.slice(0, 1000))], data.schema()); let write_params = WriteParams { max_rows_per_file: 1000, + use_legacy_format, ..Default::default() }; let mut dataset = Dataset::write(reader, test_uri, Some(write_params)) @@ -1490,8 +1512,9 @@ mod tests { assert!(fragments[0].metadata.deletion_file.is_none()); } + #[rstest] #[tokio::test] - async fn test_compact_distributed() { + async fn test_compact_distributed(#[values(false, true)] use_legacy_format: bool) { // Can run the tasks independently // Can provide subset of tasks to commit_compaction // Once committed, can't commit remaining tasks @@ -1504,6 +1527,7 @@ mod tests { let reader = RecordBatchIterator::new(vec![Ok(data.slice(0, 9000))], data.schema()); let write_params = WriteParams { max_rows_per_file: 1000, + use_legacy_format, ..Default::default() }; let mut dataset = Dataset::write(reader, test_uri, Some(write_params)) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index dcc2fe1e38..ff34d4d94b 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -2208,6 +2208,42 @@ mod test { assert_eq!(expected_i, actual_i); } + #[rstest] + #[tokio::test] + async fn test_only_row_id(#[values(false, true)] use_legacy_format: bool) { + let test_ds = TestVectorDataset::new(use_legacy_format).await.unwrap(); + let dataset = &test_ds.dataset; + + let mut scan = dataset.scan(); + scan.project::<&str>(&[]).unwrap().with_row_id(); + + let results = scan + .try_into_stream() + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + assert_eq!(results.len(), 1); + let batch = &results[0]; + + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.num_rows(), 400); + assert_eq!( + batch.schema().as_ref(), + 
&ArrowSchema::new(vec![ArrowField::new(ROW_ID, DataType::UInt64, true,)]) + ); + + let expected_row_ids: Vec = (0..400).map(|i| i as u64).collect(); + let actual_row_ids: Vec = as_primitive_array::(batch.column(0).as_ref()) + .values() + .iter() + .copied() + .collect(); + assert_eq!(expected_row_ids, actual_row_ids); + } + #[tokio::test] async fn test_scan_unordered_with_row_id() { // This test doesn't make sense for v2 files, there is no way to get an out-of-order scan diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 88f635377a..d536b6a936 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -12,14 +12,19 @@ use async_trait::async_trait; use futures::{stream, StreamExt, TryStreamExt}; use itertools::Itertools; use lance_file::reader::FileReader; +use lance_file::v2; use lance_index::optimize::OptimizeOptions; use lance_index::pb::index::Implementation; use lance_index::scalar::expression::IndexInformationProvider; use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::scalar::ScalarIndex; use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::hnsw::HNSW; +use lance_index::vector::sq::ScalarQuantizer; pub use lance_index::IndexParams; +use lance_index::INDEX_METADATA_SCHEMA_KEY; use lance_index::{pb, vector::VectorIndex, DatasetIndexExt, Index, IndexType, INDEX_FILE_NAME}; +use lance_io::scheduler::ScanScheduler; use lance_io::traits::Reader; use lance_io::utils::{ read_last_block, read_message, read_message_from_buf, read_metadata_offset, read_version, @@ -36,7 +41,7 @@ use vector::ivf::v2::IVFIndex; pub(crate) mod append; pub(crate) mod cache; -pub(crate) mod prefilter; +pub mod prefilter; pub mod scalar; pub mod vector; @@ -519,7 +524,7 @@ impl DatasetIndexInternalExt for Dataset { // the index file is in lance format since version (0,2) // TODO: we need to change the legacy IVF_PQ to be in lance format - match (major_version, minor_version) { + let index = match 
(major_version, minor_version) { (0, 1) | (0, 0) => { let proto = open_index_proto(reader.as_ref()).await?; match &proto.implementation { @@ -557,14 +562,70 @@ impl DatasetIndexInternalExt for Dataset { } (0, 3) => { - let ivf = IVFIndex::::try_new( - self.object_store.clone(), - self.indices_dir(), - uuid.to_owned(), - Arc::downgrade(&self.session), - ) - .await?; - Ok(Arc::new(ivf)) + let scheduler = ScanScheduler::new(self.object_store.clone(), 16); + let file = scheduler.open_file(&index_file).await?; + let reader = + v2::reader::FileReader::try_open(file, None, Default::default()).await?; + let index_metadata = reader + .schema() + .metadata + .get(INDEX_METADATA_SCHEMA_KEY) + .ok_or(Error::Index { + message: "Index Metadata not found".to_owned(), + location: location!(), + })?; + let index_metadata: lance_index::IndexMetadata = + serde_json::from_str(index_metadata)?; + let field = self.schema().field(column).ok_or_else(|| Error::Index { + message: format!("Column {} does not exist in the schema", column), + location: location!(), + })?; + + let value_type = if let DataType::FixedSizeList(df, _) = field.data_type() { + Result::Ok(df.data_type().to_owned()) + } else { + return Err(Error::Index { + message: format!("Column {} is not a vector column", column), + location: location!(), + }); + }?; + match index_metadata.index_type.as_str() { + "FLAT" => match value_type { + DataType::Float16 | DataType::Float32 | DataType::Float64 => { + let ivf = IVFIndex::::try_new( + self.object_store.clone(), + self.indices_dir(), + uuid.to_owned(), + Arc::downgrade(&self.session), + ) + .await?; + Ok(Arc::new(ivf) as Arc) + } + _ => Err(Error::Index { + message: format!( + "the field type {} is not supported for FLAT index", + field.data_type() + ), + location: location!(), + }), + }, + + "HNSW" => { + let ivf = IVFIndex::::try_new( + self.object_store.clone(), + self.indices_dir(), + uuid.to_owned(), + Arc::downgrade(&self.session), + ) + .await?; + Ok(Arc::new(ivf) as Arc) 
+ } + + _ => Err(Error::Index { + message: format!("Unsupported index type: {}", index_metadata.index_type), + location: location!(), + }), + } } _ => Err(Error::Index { @@ -572,7 +633,10 @@ impl DatasetIndexInternalExt for Dataset { .to_owned(), location: location!(), }), - } + }; + let index = index?; + self.session.index_cache.insert_vector(uuid, index.clone()); + Ok(index) } async fn scalar_index_info(&self) -> Result { diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 19aecf2424..6e32234042 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -10,7 +10,6 @@ use std::{any::Any, collections::HashMap}; pub mod builder; pub mod ivf; pub mod pq; -pub mod sq; mod traits; mod utils; @@ -21,7 +20,8 @@ use arrow::datatypes::Float32Type; use builder::IvfIndexBuilder; use lance_file::reader::FileReader; use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; -use lance_index::vector::ivf::storage::IvfData; +use lance_index::vector::hnsw::HNSW; +use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::ProductQuantizerImpl; use lance_index::vector::v3::shuffler::IvfShuffler; use lance_index::vector::{ @@ -40,7 +40,6 @@ use lance_linalg::distance::*; use lance_table::format::Index as IndexMetadata; use snafu::{location, Location}; use tracing::instrument; -use utils::get_vector_dim; use uuid::Uuid; use self::{ivf::*, pq::PQIndex}; @@ -79,20 +78,6 @@ impl VectorIndexParams { } } - // IVF_HNSW index - pub fn ivf_hnsw( - num_partitions: usize, - distance_type: DistanceType, - hnsw_params: HnswBuildParams, - ) -> Self { - let ivf_params = IvfBuildParams::new(num_partitions); - let stages = vec![StageParams::Ivf(ivf_params), StageParams::Hnsw(hnsw_params)]; - Self { - stages, - metric_type: distance_type, - } - } - /// Create index parameters for `IVF_PQ` index. 
/// /// Parameters @@ -146,11 +131,6 @@ impl VectorIndexParams { hnsw: HnswBuildParams, pq: PQBuildParams, ) -> Self { - let hnsw = match &hnsw.parallel_limit { - Some(_) => hnsw, - None => hnsw.parallel_limit(num_cpus::get().div_ceil(ivf.num_partitions)), - }; - let stages = vec![ StageParams::Ivf(ivf), StageParams::Hnsw(hnsw), @@ -170,11 +150,6 @@ impl VectorIndexParams { hnsw: HnswBuildParams, sq: SQBuildParams, ) -> Self { - let hnsw = match &hnsw.parallel_limit { - Some(_) => hnsw, - None => hnsw.parallel_limit(num_cpus::get().div_ceil(ivf.num_partitions)), - }; - let stages = vec![ StageParams::Ivf(ivf), StageParams::Hnsw(hnsw), @@ -237,7 +212,6 @@ pub(crate) async fn build_vector_index( params: &VectorIndexParams, ) -> Result<()> { let stages = ¶ms.stages; - let dim = get_vector_dim(dataset, column)?; if stages.is_empty() { return Err(Error::Index { @@ -255,21 +229,16 @@ pub(crate) async fn build_vector_index( }; let temp_dir = tempfile::tempdir()?; let path = temp_dir.path().to_str().unwrap().into(); - let shuffler = IvfShuffler::new( - dataset.object_store().clone(), - path, - ivf_params.num_partitions, - ); - let quantizer = FlatQuantizer::new(dim, params.metric_type); - IvfIndexBuilder::::new( + let shuffler = IvfShuffler::new(path, ivf_params.num_partitions); + IvfIndexBuilder::::new( dataset.clone(), column.to_owned(), dataset.indices_dir().child(uuid), params.metric_type, Box::new(shuffler), - ivf_params.clone(), + Some(ivf_params.clone()), + Some(()), (), - quantizer, )? .build() .await?; @@ -313,6 +282,10 @@ pub(crate) async fn build_vector_index( }); }; + let temp_dir = tempfile::tempdir()?; + let path = temp_dir.path().to_str().unwrap().into(); + let shuffler = IvfShuffler::new(path, ivf_params.num_partitions); + // with quantization if len > 2 { match stages.last().unwrap() { @@ -330,17 +303,18 @@ pub(crate) async fn build_vector_index( .await? 
} StageParams::SQ(sq_params) => { - build_ivf_hnsw_sq_index( - dataset, - column, - name, - uuid, + IvfIndexBuilder::::new( + dataset.clone(), + column.to_owned(), + dataset.indices_dir().child(uuid), params.metric_type, - ivf_params, - hnsw_params, - sq_params, - ) - .await? + Box::new(shuffler), + Some(ivf_params.clone()), + Some(sq_params.clone()), + hnsw_params.clone(), + )? + .build() + .await?; } _ => { return Err(Error::Index { @@ -431,7 +405,7 @@ pub(crate) async fn open_vector_index( location: location!(), }); } - let ivf = Ivf::try_from(ivf_pb)?; + let ivf = IvfModel::try_from(ivf_pb.to_owned())?; last_stage = Some(Arc::new(IVFIndex::try_new( dataset.session.clone(), uuid, @@ -468,7 +442,6 @@ pub(crate) async fn open_vector_index( }); } let idx = last_stage.unwrap(); - dataset.session.index_cache.insert_vector(uuid, idx.clone()); Ok(idx) } @@ -498,7 +471,7 @@ pub(crate) async fn open_vector_index_v2( .child(INDEX_AUXILIARY_FILE_NAME); let aux_reader = dataset.object_store().open(&aux_path).await?; - let ivf_data = IvfData::load(&reader).await?; + let ivf_data = IvfModel::load(&reader).await?; let options = HNSWIndexOptions { use_residual: true }; let hnsw = HNSWIndex::>::try_new( reader.object_reader.clone(), @@ -507,7 +480,7 @@ pub(crate) async fn open_vector_index_v2( ) .await?; let pb_ivf = pb::Ivf::try_from(&ivf_data)?; - let ivf = Ivf::try_from(&pb_ivf)?; + let ivf = IvfModel::try_from(pb_ivf)?; Arc::new(IVFIndex::try_new( dataset.session.clone(), @@ -526,7 +499,7 @@ pub(crate) async fn open_vector_index_v2( .child(INDEX_AUXILIARY_FILE_NAME); let aux_reader = dataset.object_store().open(&aux_path).await?; - let ivf_data = IvfData::load(&reader).await?; + let ivf_data = IvfModel::load(&reader).await?; let options = HNSWIndexOptions { use_residual: false, }; @@ -538,7 +511,7 @@ pub(crate) async fn open_vector_index_v2( ) .await?; let pb_ivf = pb::Ivf::try_from(&ivf_data)?; - let ivf = Ivf::try_from(&pb_ivf)?; + let ivf = IvfModel::try_from(pb_ivf)?; 
Arc::new(IVFIndex::try_new( dataset.session.clone(), @@ -573,10 +546,5 @@ pub(crate) async fn open_vector_index_v2( } }; - dataset - .session - .index_cache - .insert_vector(uuid, index.clone()); - Ok(index) } diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 6d3c194b25..536c1dda66 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -3,19 +3,24 @@ use std::sync::Arc; -use arrow_array::{FixedSizeListArray, RecordBatch}; +use arrow::array::AsArray; +use arrow_array::{RecordBatch, UInt64Array}; use futures::prelude::stream::{StreamExt, TryStreamExt}; use itertools::Itertools; -use lance_core::{Error, Result}; +use lance_arrow::RecordBatchExt; +use lance_core::{Error, Result, ROW_ID_FIELD}; use lance_encoding::decoder::{DecoderMiddlewareChain, FilterExpression}; use lance_file::v2::{reader::FileReader, writer::FileWriter}; +use lance_index::vector::flat::storage::FlatStorage; +use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::quantizer::QuantizerBuildParams; +use lance_index::vector::storage::STORAGE_METADATA_KEY; +use lance_index::vector::v3::shuffler::IvfShufflerReader; +use lance_index::vector::VectorIndex; use lance_index::{ pb, vector::{ - ivf::{ - storage::{IvfData, IVF_METADATA_KEY}, - IvfBuildParams, - }, + ivf::{storage::IVF_METADATA_KEY, IvfBuildParams}, quantizer::Quantization, storage::{StorageBuilder, VectorStore}, transform::Transformer, @@ -27,11 +32,14 @@ use lance_index::{ }, INDEX_AUXILIARY_FILE_NAME, INDEX_FILE_NAME, }; +use lance_index::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; +use lance_io::stream::RecordBatchStream; use lance_io::{ object_store::ObjectStore, scheduler::ScanScheduler, stream::RecordBatchStreamAdapter, ReadBatchParams, }; use lance_linalg::distance::DistanceType; +use log::info; use object_store::path::Path; use prost::Message; use snafu::{location, Location}; @@ -39,21 +47,39 @@ use tempfile::TempDir; use 
crate::Dataset; -use super::{utils, Ivf}; +use super::utils; +use super::v2::IVFIndex; +// Builder for IVF index +// The builder will train the IVF model and quantizer, shuffle the dataset, and build the sub index +// for each partition. +// To build the index for the whole dataset, call `build` method. +// To build the index for given IVF, quantizer, data stream, +// call `with_ivf`, `with_quantizer`, `shuffle_data`, and `build` in order. pub struct IvfIndexBuilder { dataset: Dataset, column: String, index_dir: Path, distance_type: DistanceType, shuffler: Arc, - ivf_params: IvfBuildParams, + // build params, only needed for building new IVF, quantizer + ivf_params: Option, + quantizer_params: Option, sub_index_params: S::BuildParams, - quantizer: Q, + _temp_dir: TempDir, // store this for keeping the temp dir alive temp_dir: Path, + + // fields will be set during build + ivf: Option, + quantizer: Option, + shuffle_reader: Option>, + partition_sizes: Vec<(usize, usize)>, + + // fields for merging indices + existing_indices: Vec>, } -impl IvfIndexBuilder { +impl IvfIndexBuilder { #[allow(clippy::too_many_arguments)] pub fn new( dataset: Dataset, @@ -61,12 +87,12 @@ impl IvfIndexBuilder { index_dir: Path, distance_type: DistanceType, shuffler: Box, - ivf_params: IvfBuildParams, + ivf_params: Option, + quantizer_params: Option, sub_index_params: S::BuildParams, - quantizer: Q, ) -> Result { let temp_dir = TempDir::new()?; - let temp_dir = Path::from(temp_dir.path().to_str().unwrap()); + let temp_dir_path = temp_dir.path().to_str().unwrap().into(); Ok(Self { dataset, column, @@ -74,74 +100,136 @@ impl IvfIndexBuilder { distance_type, shuffler: shuffler.into(), ivf_params, + quantizer_params, sub_index_params, - quantizer, - temp_dir, + _temp_dir: temp_dir, + temp_dir: temp_dir_path, + // fields will be set during build + ivf: None, + quantizer: None, + shuffle_reader: None, + partition_sizes: Vec::new(), + existing_indices: Vec::new(), }) } - pub async fn build(&self) 
-> Result<()> { - // step 1. train IVF - let ivf = self.load_or_build_ivf().await?; - - // step 2. shuffle data - let reader = self.shuffle_data(ivf.centroids.clone()).await?; - let partition_build_order = (0..self.ivf_params.num_partitions) - .map(|partition_id| reader.partiton_size(partition_id)) - .collect::>>()? - // sort by partition size in descending order - .into_iter() - .enumerate() - .map(|(idx, x)| (x, idx)) - .sorted() - .rev() - .map(|(_, idx)| idx) - .collect::>(); + pub fn new_incremental( + dataset: Dataset, + column: String, + index_dir: Path, + distance_type: DistanceType, + shuffler: Box, + sub_index_params: S::BuildParams, + ) -> Result { + Self::new( + dataset, + column, + index_dir, + distance_type, + shuffler, + None, + None, + sub_index_params, + ) + } - // step 3. build sub index - let mut partition_sizes = Vec::with_capacity(self.ivf_params.num_partitions); - for &partition in &partition_build_order { - let partition_data = reader.read_partition(partition).await?.ok_or(Error::io( - format!("partition {} is empty", partition).as_str(), - location!(), - ))?; - let batches = partition_data.try_collect::>().await?; - let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; + // build the index with the all data in the dataset, + pub async fn build(&mut self) -> Result<()> { + // step 1. train IVF & quantizer + if self.ivf.is_none() { + self.with_ivf(self.load_or_build_ivf().await?); + } + if self.quantizer.is_none() { + self.with_quantizer(self.load_or_build_quantizer().await?); + } - let sizes = self.build_partition(partition, &batch).await?; - partition_sizes.push(sizes); + // step 2. shuffle the dataset + if self.partition_sizes.is_empty() { + self.shuffle_dataset().await?; } + // step 3. build partitions + self.build_partitions().await?; + // step 4. 
merge all partitions - self.merge_partitions(ivf.centroids, partition_sizes) - .await?; + self.merge_partitions().await?; Ok(()) } - async fn load_or_build_ivf(&self) -> Result { + pub fn with_ivf(&mut self, ivf: IvfModel) -> &mut Self { + self.ivf = Some(ivf); + self + } + + pub fn with_quantizer(&mut self, quantizer: Q) -> &mut Self { + self.quantizer = Some(quantizer); + self + } + + pub fn with_existing_indices(&mut self, indices: Vec>) -> &mut Self { + self.existing_indices = indices; + self + } + + async fn load_or_build_ivf(&self) -> Result { + let ivf_params = self.ivf_params.as_ref().ok_or(Error::invalid_input( + "IVF build params not set", + location!(), + ))?; let dim = utils::get_vector_dim(&self.dataset, &self.column)?; super::build_ivf_model( &self.dataset, &self.column, dim, self.distance_type, - &self.ivf_params, + ivf_params, ) .await // TODO: load ivf model } - async fn shuffle_data(&self, centroids: FixedSizeListArray) -> Result> { - let transformer = Arc::new(lance_index::vector::ivf::new_ivf_with_quantizer( - centroids, - self.distance_type, - &self.column, - self.quantizer.clone().into(), - Some(0..self.ivf_params.num_partitions as u32), - )?); + async fn load_or_build_quantizer(&self) -> Result { + let quantizer_params = self.quantizer_params.as_ref().ok_or(Error::invalid_input( + "quantizer build params not set", + location!(), + ))?; + let sample_size_hint = quantizer_params.sample_size(); + + let start = std::time::Instant::now(); + info!( + "loading training data for quantizer. sample size: {}", + sample_size_hint + ); + let training_data = + utils::maybe_sample_training_data(&self.dataset, &self.column, sample_size_hint) + .await?; + info!( + "Finished loading training data in {:02} seconds", + start.elapsed().as_secs_f32() + ); + + // If metric type is cosine, normalize the training data, and after this point, + // treat the metric type as L2. 
+ let (training_data, dt) = if self.distance_type == DistanceType::Cosine { + let training_data = lance_linalg::kernels::normalize_fsl(&training_data)?; + (training_data, DistanceType::L2) + } else { + (training_data, self.distance_type) + }; + + info!("Start to train quantizer"); + let start = std::time::Instant::now(); + let quantizer = Q::build(&training_data, dt, quantizer_params)?; + info!( + "Trained quantizer in {:02} seconds", + start.elapsed().as_secs_f32() + ); + Ok(quantizer) + } + async fn shuffle_dataset(&mut self) -> Result<()> { let stream = self .dataset .scan() @@ -150,78 +238,229 @@ impl IvfIndexBuilder { .with_row_id() .try_into_stream() .await?; + self.shuffle_data(Some(stream)).await?; + Ok(()) + } + // shuffle the unindexed data and exsiting indices + // data must be with schema | ROW_ID | vector_column | + // the shuffled data will be with schema | ROW_ID | PART_ID | code_column | + pub async fn shuffle_data( + &mut self, + data: Option, + ) -> Result<&mut Self> { + if data.is_none() { + return Ok(self); + } + let data = data.unwrap(); + + let ivf = self.ivf.as_ref().ok_or(Error::invalid_input( + "IVF not set before shuffle data", + location!(), + ))?; + let quantizer = self.quantizer.clone().ok_or(Error::invalid_input( + "quantizer not set before shuffle data", + location!(), + ))?; + + let transformer = Arc::new( + lance_index::vector::ivf::new_ivf_transformer_with_quantizer( + ivf.centroids.clone().unwrap(), + self.distance_type, + &self.column, + quantizer.into(), + Some(0..ivf.num_partitions() as u32), + )?, + ); let mut transformed_stream = Box::pin( - stream - .map(move |batch| { - let ivf_transformer = transformer.clone(); - tokio::spawn(async move { ivf_transformer.transform(&batch?) }) - }) - .buffered(num_cpus::get()) - .map(|x| x.unwrap()) - .peekable(), + data.map(move |batch| { + let ivf_transformer = transformer.clone(); + tokio::spawn(async move { ivf_transformer.transform(&batch?) 
}) + }) + .buffered(num_cpus::get()) + .map(|x| x.unwrap()) + .peekable(), ); let batch = transformed_stream.as_mut().peek().await; let schema = match batch { Some(Ok(b)) => b.schema(), Some(Err(e)) => panic!("do this better: error reading first batch: {:?}", e), - None => panic!("no data"), + None => { + log::info!("no data to shuffle"); + self.shuffle_reader = Some(Box::new(IvfShufflerReader::new( + self.dataset.object_store.clone(), + self.temp_dir.clone(), + vec![0; ivf.num_partitions()], + ))); + return Ok(self); + } }; - self.shuffler - .shuffle(Box::new(RecordBatchStreamAdapter::new( - schema, - transformed_stream, - ))) - .await + self.shuffle_reader = Some( + self.shuffler + .shuffle(Box::new(RecordBatchStreamAdapter::new( + schema, + transformed_stream, + ))) + .await?, + ); + + Ok(self) + } + + async fn build_partitions(&mut self) -> Result<&mut Self> { + let ivf = self.ivf.as_ref().ok_or(Error::invalid_input( + "IVF not set before building partitions", + location!(), + ))?; + + let reader = self.shuffle_reader.as_ref().ok_or(Error::invalid_input( + "shuffle reader not set before building partitions", + location!(), + ))?; + + let partition_build_order = (0..ivf.num_partitions()) + .map(|partition_id| reader.partition_size(partition_id)) + .collect::>>()? 
+ // sort by partition size in descending order + .into_iter() + .enumerate() + .sorted_unstable_by(|(_, a), (_, b)| b.cmp(a)) + .map(|(idx, _)| idx) + .collect::>(); + + let mut partition_sizes = vec![(0, 0); ivf.num_partitions()]; + for (i, &partition) in partition_build_order.iter().enumerate() { + log::info!( + "building partition {}, progress {}/{}", + partition, + i + 1, + ivf.num_partitions(), + ); + let mut batches = Vec::new(); + for existing_index in self.existing_indices.iter() { + let existing_index = existing_index + .as_any() + .downcast_ref::>() + .ok_or(Error::invalid_input( + "existing index is not IVF index", + location!(), + ))?; + + let part_storage = existing_index.load_partition_storage(partition).await?; + batches.extend( + self.take_vectors(part_storage.row_ids().cloned().collect_vec().as_ref()) + .await?, + ); + } + + match reader.partition_size(partition)? { + 0 => continue, + _ => { + let partition_data = + reader.read_partition(partition).await?.ok_or(Error::io( + format!("partition {} is empty", partition).as_str(), + location!(), + ))?; + batches.extend(partition_data.try_collect::>().await?); + } + } + + let num_rows = batches.iter().map(|b| b.num_rows()).sum::(); + if num_rows == 0 { + continue; + } + let mut batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; + if self.distance_type == DistanceType::Cosine { + let vectors = batch + .column_by_name(&self.column) + .ok_or(Error::invalid_input( + format!("column {} not found", self.column).as_str(), + location!(), + ))? 
+ .as_fixed_size_list(); + let vectors = lance_linalg::kernels::normalize_fsl(vectors)?; + batch = batch.replace_column_by_name(&self.column, Arc::new(vectors))?; + } + + let sizes = self.build_partition(partition, &batch).await?; + partition_sizes[partition] = sizes; + log::info!( + "partition {} built, progress {}/{}", + partition, + i + 1, + ivf.num_partitions() + ); + } + self.partition_sizes = partition_sizes; + Ok(self) } async fn build_partition(&self, part_id: usize, batch: &RecordBatch) -> Result<(usize, usize)> { - let object_store = ObjectStore::local(); + let quantizer = self.quantizer.clone().ok_or(Error::invalid_input( + "quantizer not set before building partition", + location!(), + ))?; // build quantized vector storage - let storage = StorageBuilder::new( - self.column.clone(), - self.distance_type, - self.quantizer.clone(), - ) - .build(batch)?; - let path = self.temp_dir.child(format!("storage_part{}", part_id)); - let writer = object_store.create(&path).await?; - let mut writer = FileWriter::try_new( - writer, - path.to_string(), - storage.schema().as_ref().try_into()?, - Default::default(), - )?; - for batch in storage.to_batches()? { - writer.write_batch(&batch).await?; - } - let storage_len = writer.finish().await? as usize; + let object_store = ObjectStore::local(); + let storage_len = { + let storage = StorageBuilder::new(self.column.clone(), self.distance_type, quantizer) + .build(batch)?; + let path = self.temp_dir.child(format!("storage_part{}", part_id)); + let writer = object_store.create(&path).await?; + let mut writer = FileWriter::try_new( + writer, + path.to_string(), + storage.schema().as_ref().try_into()?, + Default::default(), + )?; + for batch in storage.to_batches()? { + writer.write_batch(&batch).await?; + } + writer.finish().await? 
as usize + }; // build the sub index, with in-memory storage - let sub_index = S::index_vectors(&storage, self.sub_index_params.clone())?; - let path = self.temp_dir.child(format!("index_part{}", part_id)); - let writer = object_store.create(&path).await?; - let index_batch = sub_index.to_batch()?; - let mut writer = FileWriter::try_new( - writer, - path.to_string(), - index_batch.schema_ref().as_ref().try_into()?, - Default::default(), - )?; - writer.write_batch(&index_batch).await?; - let index_len = writer.finish().await? as usize; + let index_len = { + let distance_type = match self.distance_type { + DistanceType::Cosine | DistanceType::Dot => DistanceType::L2, + _ => self.distance_type, + }; + let vectors = batch[&self.column].as_fixed_size_list(); + let flat_storage = FlatStorage::new(vectors.clone(), distance_type); + let sub_index = S::index_vectors(&flat_storage, self.sub_index_params.clone())?; + let path = self.temp_dir.child(format!("index_part{}", part_id)); + let writer = object_store.create(&path).await?; + let index_batch = sub_index.to_batch()?; + let mut writer = FileWriter::try_new( + writer, + path.to_string(), + index_batch.schema_ref().as_ref().try_into()?, + Default::default(), + )?; + writer.write_batch(&index_batch).await?; + writer.finish().await? 
as usize + }; Ok((storage_len, index_len)) } - async fn merge_partitions( - &self, - centroids: FixedSizeListArray, - partition_sizes: Vec<(usize, usize)>, - ) -> Result<()> { + async fn merge_partitions(&mut self) -> Result<()> { + let ivf = self.ivf.as_ref().ok_or(Error::invalid_input( + "IVF not set before merge partitions", + location!(), + ))?; + let quantizer = self.quantizer.clone().ok_or(Error::invalid_input( + "quantizer not set before merge partitions", + location!(), + ))?; + let partition_sizes = std::mem::take(&mut self.partition_sizes); + if partition_sizes.is_empty() { + return Err(Error::invalid_input("no partition to merge", location!())); + } + // prepare the final writers let storage_path = self.index_dir.child(INDEX_AUXILIARY_FILE_NAME); let index_path = self.index_dir.child(INDEX_FILE_NAME); @@ -234,12 +473,16 @@ impl IvfIndexBuilder { )?; // maintain the IVF partitions - let mut storage_ivf = IvfData::empty(); - let mut index_ivf = IvfData::with_centroids(Arc::new(centroids)); + let mut storage_ivf = IvfModel::empty(); + let mut index_ivf = IvfModel::new(ivf.centroids.clone().unwrap()); + let mut partition_storage_metadata = Vec::with_capacity(partition_sizes.len()); + let mut partition_index_metadata = Vec::with_capacity(partition_sizes.len()); let scheduler = ScanScheduler::new(Arc::new(ObjectStore::local()), 64); for (part_id, (storage_size, index_size)) in partition_sizes.into_iter().enumerate() { + log::info!("merging partition {}/{}", part_id, ivf.num_partitions()); if storage_size == 0 { - storage_ivf.add_partition(0) + storage_ivf.add_partition(0); + partition_storage_metadata.push(quantizer.metadata(None)?.to_string()); } else { let storage_part_path = self.temp_dir.child(format!("storage_part{}", part_id)); let reader = FileReader::try_open( @@ -268,10 +511,19 @@ impl IvfIndexBuilder { } storage_writer.as_mut().unwrap().write_batch(&batch).await?; storage_ivf.add_partition(batch.num_rows() as u32); + partition_storage_metadata.push( 
+ reader + .schema() + .metadata + .get(STORAGE_METADATA_KEY) + .cloned() + .unwrap_or_default(), + ); } if index_size == 0 { - index_ivf.add_partition(0) + index_ivf.add_partition(0); + partition_index_metadata.push(String::new()); } else { let index_part_path = self.temp_dir.child(format!("index_part{}", part_id)); let reader = FileReader::try_open( @@ -292,7 +544,16 @@ impl IvfIndexBuilder { let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; index_writer.write_batch(&batch).await?; index_ivf.add_partition(batch.num_rows() as u32); + partition_index_metadata.push( + reader + .schema() + .metadata + .get(S::metadata_key()) + .cloned() + .unwrap_or_default(), + ); } + log::info!("merged partition {}/{}", part_id, ivf.num_partitions()); } let mut storage_writer = storage_writer.unwrap(); @@ -303,22 +564,52 @@ impl IvfIndexBuilder { .await?; storage_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string()); storage_writer.add_schema_metadata( - Q::metadata_key(), - self.quantizer.metadata(None)?.to_string(), + STORAGE_METADATA_KEY, + serde_json::to_string(&partition_storage_metadata)?, ); let index_ivf_pb = pb::Ivf::try_from(&index_ivf)?; - index_writer.add_schema_metadata(DISTANCE_TYPE_KEY, self.distance_type.to_string()); + let index_metadata = IndexMetadata { + index_type: S::name().to_string(), + distance_type: self.distance_type.to_string(), + }; + index_writer.add_schema_metadata( + INDEX_METADATA_SCHEMA_KEY, + serde_json::to_string(&index_metadata)?, + ); let ivf_buffer_pos = index_writer .add_global_buffer(index_ivf_pb.encode_to_vec().into()) .await?; index_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string()); + index_writer.add_schema_metadata( + S::metadata_key(), + serde_json::to_string(&partition_index_metadata)?, + ); storage_writer.finish().await?; index_writer.finish().await?; Ok(()) } + + // take vectors from the dataset + // used for reading vectors from existing indices + async fn 
take_vectors(&self, row_ids: &[u64]) -> Result> { + let column = self.column.clone(); + let object_store = self.dataset.object_store().clone(); + let projection = self.dataset.schema().project(&[column.as_str()])?; + // arrow uses i32 for index, so we chunk the row ids to avoid large batch causing overflow + let mut batches = Vec::new(); + for chunk in row_ids.chunks(object_store.block_size()) { + let batch = self.dataset.take_rows(chunk, &projection).await?; + let batch = batch.try_with_column( + ROW_ID_FIELD.clone(), + Arc::new(UInt64Array::from(chunk.to_vec())), + )?; + batches.push(batch); + } + Ok(batches) + } } #[cfg(test)] @@ -331,6 +622,8 @@ mod tests { use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::hnsw::HNSW; + use lance_index::vector::sq::builder::SQBuildParams; + use lance_index::vector::sq::ScalarQuantizer; use lance_index::vector::{ flat::index::{FlatIndex, FlatQuantizer}, ivf::IvfBuildParams, @@ -381,56 +674,50 @@ mod tests { let ivf_params = IvfBuildParams::default(); let index_dir = tempdir().unwrap(); let index_dir = Path::from(index_dir.path().to_str().unwrap()); - let shuffler = IvfShuffler::new( - dataset.object_store().clone(), - index_dir.child("shuffled"), - ivf_params.num_partitions, - ); + let shuffler = IvfShuffler::new(index_dir.child("shuffled"), ivf_params.num_partitions); - let fq = FlatQuantizer::new(DIM, DistanceType::L2); - let builder = super::IvfIndexBuilder::::new( + super::IvfIndexBuilder::::new( dataset, "vector".to_owned(), index_dir, DistanceType::L2, Box::new(shuffler), - ivf_params, + Some(ivf_params), + Some(()), (), - fq, ) + .unwrap() + .build() + .await .unwrap(); - - builder.build().await.unwrap(); } #[tokio::test] - async fn test_build_ivf_hnsw() { + async fn test_build_ivf_hnsw_sq() { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); let (dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; let ivf_params = IvfBuildParams::default(); let 
hnsw_params = HnswBuildParams::default(); + let sq_params = SQBuildParams::default(); let index_dir = tempdir().unwrap(); let index_dir = Path::from(index_dir.path().to_str().unwrap()); - let shuffler = IvfShuffler::new( - dataset.object_store().clone(), - index_dir.child("shuffled"), - ivf_params.num_partitions, - ); + let shuffler = IvfShuffler::new(index_dir.child("shuffled"), ivf_params.num_partitions); - let fq = FlatQuantizer::new(DIM, DistanceType::L2); - let builder = super::IvfIndexBuilder::::new( + super::IvfIndexBuilder::::new( dataset, "vector".to_owned(), index_dir, DistanceType::L2, Box::new(shuffler), - ivf_params, + Some(ivf_params), + Some(sq_params), hnsw_params, - fq, ) + .unwrap() + .build() + .await .unwrap(); - builder.build().await.unwrap(); } } diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index cd7498eadc..0214e83a99 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -14,11 +14,14 @@ mod test { use approx::assert_relative_eq; use arrow::array::AsArray; - use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch}; + use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, UInt32Array}; use arrow_schema::{DataType, Field, Schema}; use async_trait::async_trait; use deepsize::{Context, DeepSizeOf}; use lance_arrow::FixedSizeListArrayExt; + use lance_index::vector::ivf::storage::IvfModel; + use lance_index::vector::quantizer::{QuantizationType, Quantizer}; + use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{vector::Query, Index, IndexType}; use lance_io::{local::LocalObjectReader, traits::Reader}; use lance_linalg::distance::MetricType; @@ -29,7 +32,7 @@ mod test { use crate::{ index::{ prefilter::{DatasetPreFilter, PreFilter}, - vector::ivf::{IVFIndex, Ivf}, + vector::ivf::IVFIndex, }, session::Session, Result, @@ -63,6 +66,10 @@ mod test { self } + fn as_vector_index(self: Arc) -> Result> { + 
Ok(self) + } + /// Retrieve index statistics as a JSON Value fn statistics(&self) -> Result { Ok(serde_json::Value::Null) @@ -93,6 +100,19 @@ mod test { Ok(self.ret_val.clone()) } + fn find_partitions(&self, _: &Query) -> Result { + unimplemented!("only for IVF") + } + + async fn search_in_partition( + &self, + _: usize, + _: &Query, + _: Arc, + ) -> Result { + unimplemented!("only for IVF") + } + fn is_loadable(&self) -> bool { true } @@ -122,6 +142,18 @@ mod test { Ok(()) } + fn ivf_model(&self) -> IvfModel { + unimplemented!("only for IVF") + } + fn quantizer(&self) -> Quantizer { + unimplemented!("only for IVF") + } + + /// the index type of this vector index. + fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + unimplemented!("only for IVF") + } + /// The metric type of this vector index. fn metric_type(&self) -> MetricType { self.metric_type @@ -132,10 +164,10 @@ mod test { async fn test_ivf_residual_handling() { let centroids = Float32Array::from_iter(vec![1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0]); let centroids = FixedSizeListArray::try_new_from_values(centroids, 2).unwrap(); - let mut ivf = Ivf::new(centroids); + let mut ivf = IvfModel::new(centroids); // Add 4 partitions for _ in 0..4 { - ivf.add_partition(0, 0); + ivf.add_partition(0); } // hold on to this pointer, because the index only holds a weak reference let session = Arc::new(Session::default()); diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 573e6d3a2a..c41ede6d80 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -13,7 +13,7 @@ use arrow_arith::numeric::sub; use arrow_array::{ cast::{as_struct_array, AsArray}, types::{Float16Type, Float32Type, Float64Type}, - Array, FixedSizeListArray, Float32Array, RecordBatch, StructArray, UInt32Array, + Array, FixedSizeListArray, RecordBatch, StructArray, UInt32Array, }; use arrow_ord::sort::sort_to_indices; use arrow_schema::{DataType, Schema}; @@ -30,20 
+30,22 @@ use lance_file::{ format::MAGIC, writer::{FileWriter, FileWriterOptions}, }; -use lance_index::vector::v3::subindex::IvfSubIndex; +use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::quantizer::QuantizationType; +use lance_index::vector::v3::shuffler::IvfShuffler; +use lance_index::vector::v3::subindex::{IvfSubIndex, SubIndexType}; use lance_index::{ optimize::OptimizeOptions, vector::{ hnsw::{builder::HnswBuildParams, HNSWIndex, HNSW}, ivf::{ - builder::load_precomputed_partitions, - shuffler::shuffle_dataset, - storage::{IvfData, IVF_PARTITION_KEY}, - IvfBuildParams, + builder::load_precomputed_partitions, shuffler::shuffle_dataset, + storage::IVF_PARTITION_KEY, IvfBuildParams, }, pq::{PQBuildParams, ProductQuantizer}, quantizer::{Quantization, QuantizationMetadata, Quantizer}, - sq::{builder::SQBuildParams, ScalarQuantizer}, + sq::ScalarQuantizer, Query, VectorIndex, DIST_COL, }, Index, IndexMetadata, IndexType, INDEX_AUXILIARY_FILE_NAME, INDEX_METADATA_SCHEMA_KEY, @@ -63,7 +65,7 @@ use lance_linalg::{ distance::{DistanceType, Dot, MetricType, L2}, MatrixView, }; -use log::{debug, info}; +use log::info; use object_store::path::Path; use rand::{rngs::SmallRng, SeedableRng}; use roaring::RoaringBitmap; @@ -75,11 +77,12 @@ use uuid::Uuid; use self::io::write_hnsw_quantization_index_partitions; +use super::builder::IvfIndexBuilder; use super::{ pq::{build_pq_model, PQIndex}, utils::maybe_sample_training_data, }; -use crate::{dataset::builder::DatasetBuilder, index::vector::sq::build_sq_model}; +use crate::dataset::builder::DatasetBuilder; use crate::{ dataset::Dataset, index::{ @@ -100,7 +103,7 @@ pub struct IVFIndex { uuid: String, /// Ivf model - ivf: Ivf, + ivf: IvfModel, reader: Arc, @@ -129,7 +132,7 @@ impl IVFIndex { pub(crate) fn try_new( session: Arc, uuid: &str, - ivf: Ivf, + ivf: IvfModel, reader: Arc, sub_index: Arc, metric_type: MetricType, @@ -171,22 
+174,26 @@ impl IVFIndex { let part_index = if let Some(part_idx) = session.index_cache.get_vector(&cache_key) { part_idx } else { - if partition_id >= self.ivf.lengths.len() { + if partition_id >= self.ivf.num_partitions() { return Err(Error::Index { message: format!( "partition id {} is out of range of {} partitions", partition_id, - self.ivf.lengths.len() + self.ivf.num_partitions() ), location: location!(), }); } - let offset = self.ivf.offsets[partition_id]; - let length = self.ivf.lengths[partition_id] as usize; + let range = self.ivf.row_range(partition_id); let idx = self .sub_index - .load_partition(self.reader.clone(), offset, length, partition_id) + .load_partition( + self.reader.clone(), + range.start, + range.end - range.start, + partition_id, + ) .await?; let idx: Arc = idx.into(); if write_cache { @@ -202,7 +209,7 @@ impl IVFIndex { /// Internal API with no stability guarantees. pub fn preprocess_query(&self, partition_id: usize, query: &Query) -> Result { if self.sub_index.use_residual() { - let partition_centroids = self.ivf.centroids.value(partition_id); + let partition_centroids = self.ivf.centroids.as_ref().unwrap().value(partition_id); let residual_key = sub(&query.key, &partition_centroids)?; let mut part_query = query.clone(); part_query.key = residual_key; @@ -211,34 +218,6 @@ impl IVFIndex { Ok(query.clone()) } } - - pub(crate) async fn search_in_partition( - &self, - partition_id: usize, - query: &Query, - pre_filter: Arc, - ) -> Result { - let part_index = self.load_partition(partition_id, true).await?; - - let query = self.preprocess_query(partition_id, query)?; - let batch = part_index.search(&query, pre_filter).await?; - Ok(batch) - } - - /// find the IVF partitions ids given the query vector. - /// - /// Internal API with no stability guarantees. - /// - /// Assumes the query vector is normalized if the metric type is cosine. 
- pub fn find_partitions(&self, query: &Query) -> Result { - let mt = if self.metric_type == MetricType::Cosine { - MetricType::L2 - } else { - self.metric_type - }; - - self.ivf.find_partitions(&query.key, query.nprobes, mt) - } } impl std::fmt::Debug for IVFIndex { @@ -265,6 +244,19 @@ pub(crate) async fn optimize_vector_indices( }); } + // try cast to v1 IVFIndex, + // fallback to v2 IVFIndex if it's not v1 IVFIndex + if !existing_indices[0].as_any().is::() { + return optimize_vector_indices_v2( + &dataset, + unindexed, + vector_column, + existing_indices, + options, + ) + .await; + } + let new_uuid = Uuid::new_v4(); let object_store = dataset.object_store(); let index_file = dataset @@ -277,7 +269,7 @@ pub(crate) async fn optimize_vector_indices( .as_any() .downcast_ref::() .ok_or(Error::Index { - message: "optimizing vector index: first index is not IVF".to_string(), + message: "optimizing vector index: the first index isn't IVF".to_string(), location: location!(), })?; @@ -325,6 +317,98 @@ pub(crate) async fn optimize_vector_indices( Ok((new_uuid, merged)) } +pub(crate) async fn optimize_vector_indices_v2( + dataset: &Dataset, + unindexed: Option, + vector_column: &str, + existing_indices: &[Arc], + options: &OptimizeOptions, +) -> Result<(Uuid, usize)> { + // Sanity check the indices + if existing_indices.is_empty() { + return Err(Error::Index { + message: "optimizing vector index: no existing index found".to_string(), + location: location!(), + }); + } + let existing_indices = existing_indices + .iter() + .cloned() + .map(|idx| idx.as_vector_index()) + .collect::>>()?; + + let new_uuid = Uuid::new_v4(); + let index_dir = dataset.indices_dir().child(new_uuid.to_string()); + let ivf_model = existing_indices[0].ivf_model(); + let quantizer = existing_indices[0].quantizer(); + let distance_type = existing_indices[0].metric_type(); + let num_partitions = ivf_model.num_partitions(); + let index_type = existing_indices[0].sub_index_type(); + + let temp_dir = 
tempfile::tempdir()?; + let temp_dir = temp_dir.path().to_str().unwrap().into(); + let shuffler = Box::new(IvfShuffler::new(temp_dir, num_partitions)); + let start_pos = if options.num_indices_to_merge > existing_indices.len() { + 0 + } else { + existing_indices.len() - options.num_indices_to_merge + }; + let indices_to_merge = existing_indices[start_pos..].to_vec(); + let merged_num = indices_to_merge.len(); + match index_type { + (SubIndexType::Flat, QuantizationType::Flat) => { + IvfIndexBuilder::::new_incremental( + dataset.clone(), + vector_column.to_owned(), + index_dir, + distance_type, + shuffler, + (), + )? + .with_ivf(ivf_model) + .with_quantizer(quantizer.try_into()?) + .with_existing_indices(indices_to_merge) + .shuffle_data(unindexed) + .await? + .build() + .await?; + } + + (SubIndexType::Hnsw, QuantizationType::Scalar) => { + IvfIndexBuilder::::new( + dataset.clone(), + vector_column.to_owned(), + index_dir, + distance_type, + shuffler, + None, + None, + // TODO: get the HNSW parameters from the existing indices + HnswBuildParams::default(), + )? + .with_ivf(ivf_model) + .with_quantizer(quantizer.try_into()?) + .with_existing_indices(indices_to_merge) + .shuffle_data(unindexed) + .await? + .build() + .await?; + } + + (sub_index_type, quantizer_type) => { + return Err(Error::Index { + message: format!( + "optimizing vector index: unsupported index type IVF_{}_{}", + sub_index_type, quantizer_type + ), + location: location!(), + }); + } + } + + Ok((new_uuid, merged_num)) +} + #[allow(clippy::too_many_arguments)] async fn optimize_ivf_pq_indices( first_idx: &IVFIndex, @@ -340,8 +424,8 @@ async fn optimize_ivf_pq_indices( let dim = first_idx.ivf.dimension(); // TODO: merge `lance::vector::ivf::IVF` and `lance-index::vector::ivf::Ivf`` implementations. 
- let ivf = lance_index::vector::ivf::Ivf::with_pq( - first_idx.ivf.centroids.clone(), + let ivf = lance_index::vector::ivf::IvfTransformer::with_pq( + first_idx.ivf.centroids.clone().unwrap(), metric_type, vector_column, pq_index.pq.clone(), @@ -349,10 +433,10 @@ async fn optimize_ivf_pq_indices( ); // Shuffled un-indexed data with partition. - let shuffled = if let Some(stream) = unindexed { - Some( + let shuffled = match unindexed { + Some(unindexed) => Some( shuffle_dataset( - stream, + unindexed, vector_column, ivf.into(), None, @@ -362,12 +446,11 @@ async fn optimize_ivf_pq_indices( None, ) .await?, - ) - } else { - None + ), + None => None, }; - let mut ivf_mut = Ivf::new(first_idx.ivf.centroids.clone()); + let mut ivf_mut = IvfModel::new(first_idx.ivf.centroids.clone().unwrap()); let start_pos = if options.num_indices_to_merge > existing_indices.len() { 0 @@ -420,8 +503,8 @@ async fn optimize_ivf_hnsw_indices( ) -> Result { let distance_type = first_idx.metric_type; let quantizer = hnsw_index.quantizer().clone(); - let ivf = lance_index::vector::ivf::new_ivf_with_quantizer( - first_idx.ivf.centroids.clone(), + let ivf = lance_index::vector::ivf::new_ivf_transformer_with_quantizer( + first_idx.ivf.centroids.clone().unwrap(), distance_type, vector_column, quantizer.clone(), @@ -429,10 +512,10 @@ async fn optimize_ivf_hnsw_indices( )?; // Shuffled un-indexed data with partition. 
- let shuffled = if let Some(stream) = unindexed { - Some( + let unindexed_data = match unindexed { + Some(unindexed) => Some( shuffle_dataset( - stream, + unindexed, vector_column, Arc::new(ivf), None, @@ -442,12 +525,11 @@ async fn optimize_ivf_hnsw_indices( None, ) .await?, - ) - } else { - None + ), + None => None, }; - let mut ivf_mut = Ivf::new(first_idx.ivf.centroids.clone()); + let mut ivf_mut = IvfModel::new(first_idx.ivf.centroids.clone().unwrap()); let start_pos = if options.num_indices_to_merge > existing_indices.len() { 0 @@ -549,7 +631,7 @@ async fn optimize_ivf_hnsw_indices( Some(&mut aux_writer), &mut ivf_mut, quantizer, - shuffled, + unindexed_data, Some(&indices_to_merge), ) .await?; @@ -558,12 +640,7 @@ async fn optimize_ivf_hnsw_indices( let hnsw_metadata_json = json!(hnsw_metadata); writer.add_metadata(IVF_PARTITION_KEY, &hnsw_metadata_json.to_string()); - // Convert ['Ivf'] to [`IvfData`] for new index format - let mut ivf_data = IvfData::with_centroids(Arc::new(ivf_mut.centroids.clone())); - for length in ivf_mut.lengths { - ivf_data.add_partition(length); - } - ivf_data.write(&mut writer).await?; + ivf_mut.write(&mut writer).await?; writer.finish().await?; // Write the aux file @@ -637,19 +714,22 @@ impl Index for IVFIndex { self } + fn as_vector_index(self: Arc) -> Result> { + Ok(self) + } + fn index_type(&self) -> IndexType { IndexType::Vector } fn statistics(&self) -> Result { - let partitions_statistics = self - .ivf - .lengths - .iter() - .map(|&len| IvfIndexPartitionStatistics { size: len }) + let partitions_statistics = (0..self.ivf.num_partitions()) + .map(|part_id| IvfIndexPartitionStatistics { + size: self.ivf.partition_size(part_id) as u32, + }) .collect::>(); - let centroid_vecs = centroids_to_vectors(&self.ivf.centroids)?; + let centroid_vecs = centroids_to_vectors(self.ivf.centroids.as_ref().unwrap())?; Ok(serde_json::to_value(IvfIndexStatistics { index_type: "IVF".to_string(), @@ -712,6 +792,34 @@ impl VectorIndex for IVFIndex 
{ Ok(as_struct_array(&taken_distances).into()) } + /// find the IVF partitions ids given the query vector. + /// + /// Internal API with no stability guarantees. + /// + /// Assumes the query vector is normalized if the metric type is cosine. + fn find_partitions(&self, query: &Query) -> Result { + let mt = if self.metric_type == MetricType::Cosine { + MetricType::L2 + } else { + self.metric_type + }; + + self.ivf.find_partitions(&query.key, query.nprobes, mt) + } + + async fn search_in_partition( + &self, + partition_id: usize, + query: &Query, + pre_filter: Arc, + ) -> Result { + let part_index = self.load_partition(partition_id, true).await?; + + let query = self.preprocess_query(partition_id, query)?; + let batch = part_index.search(&query, pre_filter).await?; + Ok(batch) + } + fn is_loadable(&self) -> bool { false } @@ -753,6 +861,19 @@ impl VectorIndex for IVFIndex { }) } + fn ivf_model(&self) -> IvfModel { + self.ivf.clone() + } + + fn quantizer(&self) -> Quantizer { + unimplemented!("only for v2 IVFIndex") + } + + /// the index type of this vector index. + fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + unimplemented!("only for v2 IVFIndex") + } + fn metric_type(&self) -> MetricType { self.metric_type } @@ -779,7 +900,7 @@ pub struct IvfPQIndexMetadata { pub(crate) metric_type: MetricType, /// IVF model - pub(crate) ivf: Ivf, + pub(crate) ivf: IvfModel, /// Product Quantizer pub(crate) pq: Arc, @@ -835,112 +956,6 @@ impl TryFrom<&IvfPQIndexMetadata> for pb::Index { }) } } -/// Ivf Model -#[derive(Debug, Clone)] -pub(crate) struct Ivf { - /// Centroids of each partition. - /// - /// It is a 2-D `(num_partitions * dimension)` of vector array. - pub(crate) centroids: FixedSizeListArray, - - /// Offset of each partition in the file. - offsets: Vec, - - /// Number of vectors in each partition. 
- lengths: Vec, -} - -impl Ivf { - pub(super) fn new(centroids: FixedSizeListArray) -> Self { - Self { - centroids, - offsets: vec![], - lengths: vec![], - } - } - - /// Ivf model dimension. - pub(super) fn dimension(&self) -> usize { - self.centroids.value_length() as usize - } - - /// Number of IVF partitions. - fn num_partitions(&self) -> usize { - self.centroids.len() - } - - /// Use the query vector to find `nprobes` closest partitions. - fn find_partitions( - &self, - query: &dyn Array, - nprobes: usize, - metric_type: MetricType, - ) -> Result { - let internal = - lance_index::vector::ivf::new_ivf(self.centroids.clone(), metric_type, vec![]); - internal.find_partitions(query, nprobes) - } - - /// Add the offset and length of one partition. - pub(super) fn add_partition(&mut self, offset: usize, len: u32) { - self.offsets.push(offset); - self.lengths.push(len); - } -} - -/// Convert IvfModel to protobuf. -impl TryFrom<&Ivf> for pb::Ivf { - type Error = Error; - - fn try_from(ivf: &Ivf) -> Result { - if ivf.offsets.len() != ivf.centroids.len() { - return Err(Error::io( - "Ivf model has not been populated".to_string(), - location!(), - )); - } - Ok(Self { - centroids: vec![], - offsets: ivf.offsets.iter().map(|o| *o as u64).collect(), - lengths: ivf.lengths.clone(), - centroids_tensor: Some((&ivf.centroids).try_into()?), - }) - } -} - -/// Convert IvfModel to protobuf. -impl TryFrom<&pb::Ivf> for Ivf { - type Error = Error; - - fn try_from(proto: &pb::Ivf) -> Result { - let centroids = if let Some(tensor) = proto.centroids_tensor.as_ref() { - debug!("Ivf: loading IVF centroids from index format v2"); - FixedSizeListArray::try_from(tensor)? - } else { - debug!("Ivf: loading IVF centroids from index format v1"); - // For backward-compatibility - let f32_centroids = Float32Array::from(proto.centroids.clone()); - let dimension = f32_centroids.len() / proto.lengths.len(); - FixedSizeListArray::try_new_from_values(f32_centroids, dimension as i32)? 
- }; - - let mut ivf = Self { - centroids, - offsets: proto.offsets.iter().map(|o| *o as usize).collect(), - lengths: proto.lengths.clone(), - }; - - if ivf.offsets.is_empty() && !ivf.lengths.is_empty() { - let mut offset = 0; - for len in &ivf.lengths { - ivf.offsets.push(offset); - offset += *len as usize; - } - } - - Ok(ivf) - } -} fn sanity_check<'a>(dataset: &'a Dataset, column: &str) -> Result<&'a Field> { let Some(field) = dataset.schema().field(column) else { @@ -1035,7 +1050,7 @@ pub(super) async fn build_ivf_model( dim: usize, metric_type: MetricType, params: &IvfBuildParams, -) -> Result { +) -> Result { if let Some(centroids) = params.centroids.as_ref() { info!("Pre-computed IVF centroids is provided, skip IVF training"); if centroids.values().len() != params.num_partitions * dim { @@ -1048,7 +1063,7 @@ pub(super) async fn build_ivf_model( location: location!(), }); } - return Ok(Ivf::new(centroids.as_ref().clone())); + return Ok(IvfModel::new(centroids.as_ref().clone())); } let sample_size_hint = params.num_partitions * params.sample_rate; @@ -1088,7 +1103,7 @@ async fn build_ivf_model_and_pq( metric_type: MetricType, ivf_params: &IvfBuildParams, pq_params: &PQBuildParams, -) -> Result<(Ivf, Arc)> { +) -> Result<(IvfModel, Arc)> { sanity_check_params(ivf_params, pq_params)?; info!( @@ -1122,40 +1137,6 @@ async fn build_ivf_model_and_pq( Ok((ivf_model, pq)) } -async fn build_ivf_model_and_sq( - dataset: &Dataset, - column: &str, - metric_type: MetricType, - ivf_params: &IvfBuildParams, - sq_params: &SQBuildParams, -) -> Result<(Ivf, ScalarQuantizer)> { - sanity_check_ivf_params(ivf_params)?; - - info!( - "Building vector index: IVF{},SQ{}, metric={}", - ivf_params.num_partitions, sq_params.num_bits, metric_type, - ); - - let field = sanity_check(dataset, column)?; - let dim = if let DataType::FixedSizeList(_, d) = field.data_type() { - d as usize - } else { - return Err(Error::Index { - message: format!( - "VectorIndex requires the column data type to 
be fixed size list of floats, got {}", - field.data_type() - ), - location: location!(), - }); - }; - - let ivf_model = build_ivf_model(dataset, column, dim, metric_type, ivf_params).await?; - - let sq = build_sq_model(dataset, column, metric_type, sq_params).await?; - - Ok((ivf_model, sq)) -} - async fn scan_index_field_stream( dataset: &Dataset, column: &str, @@ -1250,41 +1231,6 @@ pub async fn build_ivf_hnsw_pq_index( .await } -#[allow(clippy::too_many_arguments)] -pub async fn build_ivf_hnsw_sq_index( - dataset: &Dataset, - column: &str, - index_name: &str, - uuid: &str, - metric_type: MetricType, - ivf_params: &IvfBuildParams, - hnsw_params: &HnswBuildParams, - sq_params: &SQBuildParams, -) -> Result<()> { - let (ivf_model, sq) = - build_ivf_model_and_sq(dataset, column, metric_type, ivf_params, sq_params).await?; - let stream = scan_index_field_stream(dataset, column).await?; - let precomputed_partitions = load_precomputed_partitions_if_available(ivf_params).await?; - - write_ivf_hnsw_file( - dataset, - column, - index_name, - uuid, - &[], - ivf_model, - Quantizer::Scalar(sq), - metric_type, - hnsw_params, - stream, - precomputed_partitions, - ivf_params.shuffle_partition_batches, - ivf_params.shuffle_partition_concurrency, - ivf_params.precomputed_shuffle_buffers.clone(), - ) - .await -} - struct RemapPageTask { offset: usize, length: u32, @@ -1317,7 +1263,7 @@ impl RemapPageTask { Ok(self) } - async fn write(self, writer: &mut ObjectWriter, ivf: &mut Ivf) -> Result<()> { + async fn write(self, writer: &mut ObjectWriter, ivf: &mut IvfModel) -> Result<()> { let page = self.page.as_ref().expect("Load was not called"); let page: &PQIndex = page .as_any() @@ -1367,7 +1313,7 @@ pub(crate) async fn remap_index_file( .map(|task| task.load_and_remap(reader.clone(), index, mapping)) .buffered(num_cpus::get()); - let mut ivf = Ivf { + let mut ivf = IvfModel { centroids: index.ivf.centroids.clone(), offsets: Vec::with_capacity(index.ivf.offsets.len()), lengths: 
Vec::with_capacity(index.ivf.lengths.len()), @@ -1415,7 +1361,7 @@ async fn write_ivf_pq_file( index_name: &str, uuid: &str, transformers: &[Box], - mut ivf: Ivf, + mut ivf: IvfModel, pq: Arc, metric_type: MetricType, stream: impl RecordBatchStream + Unpin + 'static, @@ -1481,7 +1427,7 @@ async fn write_ivf_hnsw_file( _index_name: &str, uuid: &str, _transformers: &[Box], - mut ivf: Ivf, + mut ivf: IvfModel, quantizer: Quantizer, distance_type: DistanceType, hnsw_params: &HnswBuildParams, @@ -1598,12 +1544,7 @@ async fn write_ivf_hnsw_file( let hnsw_metadata_json = json!(hnsw_metadata); writer.add_metadata(IVF_PARTITION_KEY, &hnsw_metadata_json.to_string()); - // Convert ['Ivf'] to [`IvfData`] for new index format - let mut ivf_data = IvfData::with_centroids(Arc::new(ivf.centroids.clone())); - for length in ivf.lengths { - ivf_data.add_partition(length); - } - ivf_data.write(&mut writer).await?; + ivf.write(&mut writer).await?; writer.finish().await?; // Write the aux file @@ -1617,7 +1558,7 @@ async fn do_train_ivf_model( dimension: usize, metric_type: MetricType, params: &IvfBuildParams, -) -> Result +) -> Result where T::Native: Dot + L2 + Normalize, { @@ -1634,7 +1575,7 @@ where params.sample_rate, ) .await?; - Ok(Ivf::new(FixedSizeListArray::try_new_from_values( + Ok(IvfModel::new(FixedSizeListArray::try_new_from_values( centroids, dimension as i32, )?)) @@ -1645,7 +1586,7 @@ async fn train_ivf_model( data: &FixedSizeListArray, distance_type: DistanceType, params: &IvfBuildParams, -) -> Result { +) -> Result { assert!( distance_type != DistanceType::Cosine, "Cosine metric should be done by normalized L2 distance", @@ -1684,11 +1625,12 @@ mod tests { use std::ops::Range; use arrow_array::types::UInt64Type; - use arrow_array::{RecordBatchIterator, RecordBatchReader, UInt64Array}; + use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array}; use arrow_schema::Field; use itertools::Itertools; use lance_core::utils::address::RowAddress; use 
lance_core::ROW_ID; + use lance_index::vector::sq::builder::SQBuildParams; use lance_linalg::distance::l2_distance_batch; use lance_testing::datagen::{ generate_random_array, generate_random_array_with_range, generate_random_array_with_seed, @@ -2214,13 +2156,14 @@ mod tests { let ivf_model = build_ivf_model(&dataset, "vector", DIM, MetricType::L2, &ivf_params) .await .unwrap(); - assert_eq!(2, ivf_model.centroids.len()); - assert_eq!(32, ivf_model.centroids.value_length()); + assert_eq!(2, ivf_model.centroids.as_ref().unwrap().len()); + assert_eq!(32, ivf_model.centroids.as_ref().unwrap().value_length()); assert_eq!(2, ivf_model.num_partitions()); // All centroids values should be in the range [1000, 1100] ivf_model .centroids + .unwrap() .values() .as_primitive::() .values() @@ -2241,13 +2184,14 @@ mod tests { let ivf_model = build_ivf_model(&dataset, "vector", DIM, MetricType::Cosine, &ivf_params) .await .unwrap(); - assert_eq!(2, ivf_model.centroids.len()); - assert_eq!(32, ivf_model.centroids.value_length()); + assert_eq!(2, ivf_model.centroids.as_ref().unwrap().len()); + assert_eq!(32, ivf_model.centroids.as_ref().unwrap().value_length()); assert_eq!(2, ivf_model.num_partitions()); // All centroids values should be in the range [1000, 1100] ivf_model .centroids + .unwrap() .values() .as_primitive::() .values() @@ -2574,91 +2518,6 @@ mod tests { ); } - async fn test_create_ivf_hnsw_sq(distance_type: DistanceType) { - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - - let nlist = 4; - let (mut dataset, vector_array) = generate_test_dataset(test_uri, 0.0..1.0).await; - - let ivf_params = IvfBuildParams::new(nlist); - let sq_params = SQBuildParams::default(); - let hnsw_params = HnswBuildParams::default(); - let params = VectorIndexParams::with_ivf_hnsw_sq_params( - distance_type, - ivf_params, - hnsw_params, - sq_params, - ); - - dataset - .create_index(&["vector"], IndexType::Vector, None, ¶ms, false) - .await - 
.unwrap(); - - let mat = MatrixView::::try_from(vector_array.as_ref()).unwrap(); - let query = vector_array.value(0); - let query = query.as_primitive::(); - let k = 100; - let results = dataset - .scan() - .with_row_id() - .nearest("vector", query, k) - .unwrap() - .nprobs(nlist) - .try_into_stream() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - assert_eq!(1, results.len()); - assert_eq!(k, results[0].num_rows()); - - let row_ids = results[0] - .column_by_name(ROW_ID) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .map(|v| v.unwrap() as u32) - .collect::>(); - let dists = results[0] - .column_by_name("_distance") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .values() - .to_vec(); - - let results = dists.into_iter().zip(row_ids.into_iter()).collect_vec(); - let gt = ground_truth(&mat, query.values(), k, distance_type); - - let results_set = results.iter().map(|r| r.1).collect::>(); - let gt_set = gt.iter().map(|r| r.1).collect::>(); - - let recall = results_set.intersection(>_set).count() as f32 / k as f32; - assert!( - recall >= 0.9, - "recall: {}\n results: {:?}\n\ngt: {:?}", - recall, - results, - gt, - ); - } - - #[tokio::test] - async fn test_create_ivf_hnsw_sq_cosine() { - test_create_ivf_hnsw_sq(DistanceType::Cosine).await - } - - #[tokio::test] - async fn test_create_ivf_hnsw_sq_dot() { - test_create_ivf_hnsw_sq(DistanceType::Dot).await - } - #[tokio::test] async fn test_create_ivf_hnsw_with_empty_partition() { let test_dir = tempdir().unwrap(); @@ -2778,6 +2637,8 @@ mod tests { assert!(ivf_idx .ivf .centroids + .as_ref() + .unwrap() .values() .as_primitive::() .values() diff --git a/rust/lance/src/index/vector/ivf/builder.rs b/rust/lance/src/index/vector/ivf/builder.rs index d50f344481..5b06d60619 100644 --- a/rust/lance/src/index/vector/ivf/builder.rs +++ b/rust/lance/src/index/vector/ivf/builder.rs @@ -6,6 +6,7 @@ use std::ops::Range; use std::sync::Arc; use lance_file::writer::FileWriter; +use 
lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::quantizer::Quantizer; use lance_table::io::manifest::ManifestDescribing; use object_store::path::Path; @@ -15,13 +16,13 @@ use tracing::instrument; use lance_core::{traits::DatasetTakeRows, Error, Result, ROW_ID}; use lance_index::vector::{ hnsw::{builder::HnswBuildParams, HnswMetadata}, - ivf::{shuffler::shuffle_dataset, storage::IvfData}, + ivf::shuffler::shuffle_dataset, pq::ProductQuantizer, }; use lance_io::{stream::RecordBatchStream, traits::Writer}; use lance_linalg::distance::MetricType; -use crate::index::vector::ivf::{io::write_pq_partitions, Ivf}; +use crate::index::vector::ivf::io::write_pq_partitions; use super::io::write_hnsw_quantization_index_partitions; @@ -34,7 +35,7 @@ pub(super) async fn build_partitions( writer: &mut dyn Writer, data: impl RecordBatchStream + Unpin + 'static, column: &str, - ivf: &mut Ivf, + ivf: &mut IvfModel, pq: Arc, metric_type: MetricType, part_range: Range, @@ -57,8 +58,8 @@ pub(super) async fn build_partitions( }); } - let ivf_model = lance_index::vector::ivf::Ivf::with_pq( - ivf.centroids.clone(), + let ivf_transformer = lance_index::vector::ivf::IvfTransformer::with_pq( + ivf.centroids.clone().unwrap(), metric_type, column, pq.clone(), @@ -68,7 +69,7 @@ pub(super) async fn build_partitions( let stream = shuffle_dataset( data, column, - ivf_model.into(), + ivf_transformer.into(), precomputed_partitons, ivf.num_partitions() as u32, shuffle_partition_batches, @@ -93,7 +94,7 @@ pub(super) async fn build_hnsw_partitions( auxiliary_writer: Option<&mut FileWriter>, data: impl RecordBatchStream + Unpin + 'static, column: &str, - ivf: &mut Ivf, + ivf: &mut IvfModel, quantizer: Quantizer, metric_type: MetricType, hnsw_params: &HnswBuildParams, @@ -102,7 +103,7 @@ pub(super) async fn build_hnsw_partitions( shuffle_partition_batches: usize, shuffle_partition_concurrency: usize, precomputed_shuffle_buffers: Option<(Path, Vec)>, -) -> Result<(Vec, IvfData)> { +) 
-> Result<(Vec, IvfModel)> { let schema = data.schema(); if schema.column_with_name(column).is_none() { return Err(Error::Schema { @@ -117,8 +118,8 @@ pub(super) async fn build_hnsw_partitions( }); } - let ivf_model = lance_index::vector::ivf::new_ivf_with_quantizer( - ivf.centroids.clone(), + let ivf_model = lance_index::vector::ivf::new_ivf_transformer_with_quantizer( + ivf.centroids.clone().unwrap(), metric_type, column, quantizer.clone(), diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs index f65bcca6ef..0e7e4ebd7e 100644 --- a/rust/lance/src/index/vector/ivf/io.rs +++ b/rust/lance/src/index/vector/ivf/io.rs @@ -21,16 +21,12 @@ use lance_core::Error; use lance_file::reader::FileReader; use lance_file::writer::FileWriter; use lance_index::scalar::IndexWriter; -use lance_index::vector::hnsw::builder::HNSW_METADATA_KEY; +use lance_index::vector::hnsw::HNSW; use lance_index::vector::hnsw::{builder::HnswBuildParams, HnswMetadata}; -use lance_index::vector::ivf::storage::IvfData; +use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::ProductQuantizer; +use lance_index::vector::quantizer::{Quantization, Quantizer}; use lance_index::vector::v3::subindex::IvfSubIndex; -use lance_index::vector::{ - quantizer::{Quantization, Quantizer}, - sq::ScalarQuantizer, - storage::VectorStore, -}; use lance_index::vector::{PART_ID_COLUMN, PQ_CODE_COLUMN}; use lance_io::encodings::plain::PlainEncoder; use lance_io::object_store::ObjectStore; @@ -45,10 +41,10 @@ use snafu::{location, Location}; use tempfile::TempDir; use tokio::sync::Semaphore; -use super::{IVFIndex, Ivf}; +use super::IVFIndex; use crate::dataset::ROW_ID; use crate::index::vector::pq::{build_pq_storage, PQIndex}; -use crate::index::vector::sq::build_sq_storage; + use crate::Result; // TODO: make it configurable, limit by the number of CPU cores & memory @@ -147,7 +143,7 @@ async fn merge_streams( /// TODO: migrate this function to `lance-index` crate. 
pub(super) async fn write_pq_partitions( writer: &mut dyn Writer, - ivf: &mut Ivf, + ivf: &mut IvfModel, streams: Option>>>, existing_indices: Option<&[&IVFIndex]>, ) -> Result<()> { @@ -228,7 +224,7 @@ pub(super) async fn write_pq_partitions( .await?; let total_records = row_id_array.iter().map(|a| a.len()).sum::(); - ivf.add_partition(writer.tell().await?, total_records as u32); + ivf.add_partition_with_offset(writer.tell().await?, total_records as u32); if total_records > 0 { let pq_refs = pq_array.iter().map(|a| a.as_ref()).collect::>(); PlainEncoder::write(writer, &pq_refs).await?; @@ -253,11 +249,11 @@ pub(super) async fn write_hnsw_quantization_index_partitions( hnsw_params: &HnswBuildParams, writer: &mut FileWriter, mut auxiliary_writer: Option<&mut FileWriter>, - ivf: &mut Ivf, + ivf: &mut IvfModel, quantizer: Quantizer, streams: Option>>>, existing_indices: Option<&[&IVFIndex]>, -) -> Result<(Vec, IvfData)> { +) -> Result<(Vec, IvfModel)> { let hnsw_params = Arc::new(hnsw_params.clone()); let mut streams_heap = BinaryHeap::new(); @@ -384,14 +380,14 @@ pub(super) async fn write_hnsw_quantization_index_partitions( })); } - let mut aux_ivf = IvfData::empty(); + let mut aux_ivf = IvfModel::empty(); let mut hnsw_metadata = Vec::with_capacity(ivf.num_partitions()); for (part_id, task) in tasks.into_iter().enumerate() { let offset = writer.len(); let num_rows = task.await??; if num_rows == 0 { - ivf.add_partition(offset, 0); + ivf.add_partition(0); aux_ivf.add_partition(0); hnsw_metadata.push(HnswMetadata::default()); continue; @@ -414,9 +410,9 @@ pub(super) async fn write_hnsw_quantization_index_partitions( .await?; writer.write(&batches).await?; - ivf.add_partition(offset, (writer.len() - offset) as u32); + ivf.add_partition((writer.len() - offset) as u32); hnsw_metadata.push(serde_json::from_str( - part_reader.schema().metadata[HNSW_METADATA_KEY].as_str(), + part_reader.schema().metadata[HNSW::metadata_key()].as_str(), )?); std::mem::drop(part_reader); 
object_store.delete(part_file).await?; @@ -480,7 +476,8 @@ async fn build_hnsw_quantization_partition( metric_type = MetricType::L2; } - let build_hnsw = build_and_write_hnsw((*hnsw_params).clone(), vectors.clone(), writer); + let build_hnsw = + build_and_write_hnsw(vectors.clone(), (*hnsw_params).clone(), metric_type, writer); let build_store = match quantizer { Quantizer::Flat(_) => { @@ -497,13 +494,7 @@ async fn build_hnsw_quantization_partition( aux_writer.unwrap(), )), - Quantizer::Scalar(sq) => tokio::spawn(build_and_write_sq_storage( - metric_type, - row_ids, - vectors, - sq, - aux_writer.unwrap(), - )), + _ => unreachable!("IVF_HNSW_SQ has been moved to v2 index builder"), }; let index_rows = futures::join!(build_hnsw, build_store).0?; @@ -517,11 +508,12 @@ async fn build_hnsw_quantization_partition( } async fn build_and_write_hnsw( - params: HnswBuildParams, vectors: Arc, + params: HnswBuildParams, + distance_type: DistanceType, mut writer: FileWriter, ) -> Result { - let batch = params.build(vectors).await?.to_batch()?; + let batch = params.build(vectors, distance_type).await?.to_batch()?; let metadata = batch.schema_ref().metadata().clone(); writer.write_record_batch(batch).await?; writer.finish_with_metadata(&metadata).await @@ -545,26 +537,6 @@ async fn build_and_write_pq_storage( Ok(()) } -async fn build_and_write_sq_storage( - distance_type: DistanceType, - row_ids: Arc, - vectors: Arc, - sq: ScalarQuantizer, - mut writer: FileWriter, -) -> Result<()> { - let storage = spawn_cpu(move || { - let storage = build_sq_storage(distance_type, row_ids, vectors, sq)?; - Ok(storage) - }) - .await?; - - for batch in storage.to_batches()? 
{ - writer.write_record_batch(batch.clone()).await?; - } - writer.finish().await?; - Ok(()) -} - #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 3b5ef10bb3..d88ebf6d60 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -4,6 +4,7 @@ //! IVF - Inverted File index. use core::fmt; +use std::marker::PhantomData; use std::{ any::Any, collections::HashMap, @@ -19,9 +20,16 @@ use arrow_array::{RecordBatch, StructArray, UInt32Array}; use async_trait::async_trait; use deepsize::DeepSizeOf; use futures::prelude::stream::{self, StreamExt, TryStreamExt}; +use lance_arrow::RecordBatchExt; use lance_core::{cache::DEFAULT_INDEX_CACHE_SIZE, Error, Result}; use lance_encoding::decoder::{DecoderMiddlewareChain, FilterExpression}; use lance_file::v2::reader::FileReader; +use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::hnsw::HNSW; +use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::quantizer::{QuantizationType, Quantizer}; +use lance_index::vector::sq::ScalarQuantizer; +use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{ pb, vector::{ @@ -30,6 +38,7 @@ use lance_index::{ }, Index, IndexType, INDEX_AUXILIARY_FILE_NAME, INDEX_FILE_NAME, }; +use lance_index::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; use lance_io::{ object_store::ObjectStore, scheduler::ScanScheduler, traits::Reader, ReadBatchParams, }; @@ -46,20 +55,28 @@ use crate::{ session::Session, }; -use super::{centroids_to_vectors, Ivf, IvfIndexPartitionStatistics, IvfIndexStatistics}; +use super::{centroids_to_vectors, IvfIndexPartitionStatistics, IvfIndexStatistics}; + +#[derive(Debug)] +struct PartitionEntry { + index: S, + storage: Q::Storage, +} + /// IVF Index. 
#[derive(Debug)] -pub struct IVFIndex { +pub struct IVFIndex { uuid: String, /// Ivf model - ivf: Ivf, + ivf: IvfModel, reader: FileReader, - storage: IvfQuantizationStorage, + sub_index_metadata: Vec, + storage: IvfQuantizationStorage, /// Index in each partition. - sub_index_cache: Cache>, + partition_cache: Cache>>, distance_type: DistanceType, @@ -68,16 +85,18 @@ pub struct IVFIndex { /// The session cache, used when fetching pages #[allow(dead_code)] session: Weak, + + _marker: PhantomData, } -impl DeepSizeOf for IVFIndex { +impl DeepSizeOf for IVFIndex { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.uuid.deep_size_of_children(context) + self.storage.deep_size_of_children(context) // Skipping session since it is a weak ref } } -impl IVFIndex { +impl IVFIndex { /// Create a new IVF index. pub(crate) async fn try_new( object_store: Arc, @@ -95,17 +114,18 @@ impl IVFIndex { DecoderMiddlewareChain::default(), ) .await?; - let distance_type = DistanceType::try_from( + let index_metadata: IndexMetadata = serde_json::from_str( index_reader .schema() .metadata - .get(DISTANCE_TYPE_KEY) + .get(INDEX_METADATA_SCHEMA_KEY) .ok_or(Error::Index { message: format!("{} not found", DISTANCE_TYPE_KEY), location: location!(), })? 
.as_str(), )?; + let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?; let ivf_pos = index_reader .schema() @@ -121,7 +141,17 @@ impl IVFIndex { location: location!(), })?; let ivf_pb_bytes = index_reader.read_global_buffer(ivf_pos).await?; - let ivf = Ivf::try_from(&pb::Ivf::decode(ivf_pb_bytes)?)?; + let ivf = IvfModel::try_from(pb::Ivf::decode(ivf_pb_bytes)?)?; + + let sub_index_metadata = index_reader + .schema() + .metadata + .get(S::metadata_key()) + .ok_or(Error::Index { + message: format!("{} not found", S::metadata_key()), + location: location!(), + })?; + let sub_index_metadata: Vec = serde_json::from_str(sub_index_metadata)?; let storage_reader = FileReader::try_open( scheduler @@ -135,81 +165,95 @@ impl IVFIndex { DecoderMiddlewareChain::default(), ) .await?; - let storage = IvfQuantizationStorage::open(storage_reader).await?; + let storage = IvfQuantizationStorage::try_new(storage_reader).await?; Ok(Self { uuid, ivf, reader: index_reader, storage, - sub_index_cache: Cache::new(DEFAULT_INDEX_CACHE_SIZE as u64), + partition_cache: Cache::new(DEFAULT_INDEX_CACHE_SIZE as u64), + sub_index_metadata, distance_type, session, + _marker: PhantomData, }) } #[instrument(level = "debug", skip(self))] - pub async fn load_partition(&self, partition_id: usize, write_cache: bool) -> Result> { + pub async fn load_partition( + &self, + partition_id: usize, + write_cache: bool, + ) -> Result>> { let cache_key = format!("{}-ivf-{}", self.uuid, partition_id); - let part_index = if let Some(part_idx) = self.sub_index_cache.get(&cache_key) { + let part_entry = if let Some(part_idx) = self.partition_cache.get(&cache_key) { part_idx } else { - if partition_id >= self.ivf.lengths.len() { + if partition_id >= self.ivf.num_partitions() { return Err(Error::Index { message: format!( "partition id {} is out of range of {} partitions", partition_id, - self.ivf.lengths.len() + self.ivf.num_partitions() ), location: location!(), }); } - let offset = 
self.ivf.offsets[partition_id]; - let length = self.ivf.lengths[partition_id] as usize; - let batches = self - .reader - .read_stream( - ReadBatchParams::Range(offset..offset + length), - 4096, - 16, - FilterExpression::no_filter(), - )? - .peekable() - .try_collect::>() - .await?; let schema = Arc::new(self.reader.schema().as_ref().into()); - let batch = concat_batches(&schema, batches.iter())?; - let idx = Arc::new(I::load(batch)?); + let batch = match self.reader.metadata().num_rows { + 0 => RecordBatch::new_empty(schema), + _ => { + let batches = self + .reader + .read_stream( + ReadBatchParams::Range(self.ivf.row_range(partition_id)), + u32::MAX, + 1, + FilterExpression::no_filter(), + )? + .try_collect::>() + .await?; + concat_batches(&schema, batches.iter())? + } + }; + let batch = batch.add_metadata( + S::metadata_key().to_owned(), + self.sub_index_metadata[partition_id].clone(), + )?; + let idx = S::load(batch)?; + let storage = self.load_partition_storage(partition_id).await?; + let partition_entry = Arc::new(PartitionEntry { + index: idx, + storage, + }); if write_cache { - self.sub_index_cache.insert(cache_key.clone(), idx.clone()); + self.partition_cache + .insert(cache_key.clone(), partition_entry.clone()); } - idx + partition_entry }; - Ok(part_index) - } - async fn search_in_partition( - &self, - partition_id: usize, - query: &Query, - pre_filter: Arc, - ) -> Result { - let part_index = self.load_partition(partition_id, true).await?; + Ok(part_entry) + } - let query = self.preprocess_query(partition_id, query)?; - let storage = self.storage.load_partition(partition_id).await?; - let param = (&query).into(); - pre_filter.wait_for_ready().await?; - part_index.search(query.key, query.k, param, &storage, pre_filter) + pub async fn load_partition_storage(&self, partition_id: usize) -> Result { + self.storage.load_partition::(partition_id).await } /// preprocess the query vector given the partition id. /// /// Internal API with no stability guarantees. 
pub fn preprocess_query(&self, partition_id: usize, query: &Query) -> Result { - if I::use_residual() { - let partition_centroids = self.ivf.centroids.value(partition_id); + if S::use_residual() { + let partition_centroids = + self.ivf + .centroid(partition_id) + .ok_or_else(|| Error::Index { + message: format!("partition centroid {} does not exist", partition_id), + location: location!(), + })?; let residual_key = sub(&query.key, &partition_centroids)?; let mut part_query = query.clone(); part_query.key = residual_key; @@ -218,20 +262,10 @@ impl IVFIndex { Ok(query.clone()) } } - - pub fn find_partitions(&self, query: &Query) -> Result { - let dt = if self.distance_type == DistanceType::Cosine { - DistanceType::L2 - } else { - self.distance_type - }; - - self.ivf.find_partitions(&query.key, query.nprobes, dt) - } } #[async_trait] -impl Index for IVFIndex { +impl Index for IVFIndex { fn as_any(&self) -> &dyn Any { self } @@ -240,19 +274,22 @@ impl Index for IVFIndex) -> Result> { + Ok(self) + } + fn index_type(&self) -> IndexType { IndexType::Vector } fn statistics(&self) -> Result { - let partitions_statistics = self - .ivf - .lengths - .iter() - .map(|&len| IvfIndexPartitionStatistics { size: len }) + let partitions_statistics = (0..self.ivf.num_partitions()) + .map(|part_id| IvfIndexPartitionStatistics { + size: self.ivf.partition_size(part_id) as u32, + }) .collect::>(); - let centroid_vecs = centroids_to_vectors(&self.ivf.centroids)?; + let centroid_vecs = centroids_to_vectors(self.ivf.centroids.as_ref().unwrap())?; Ok(serde_json::to_value(IvfIndexStatistics { index_type: "IVF".to_string(), @@ -274,10 +311,11 @@ impl Index for IVFIndex VectorIndex - for IVFIndex +impl VectorIndex + for IVFIndex { async fn search(&self, query: &Query, pre_filter: Arc) -> Result { + pre_filter.wait_for_ready().await?; let mut query = query.clone(); if self.distance_type == DistanceType::Cosine { let key = normalize_arrow(&query.key)?; @@ -312,6 +350,45 @@ impl Result { + let dt 
= if self.distance_type == DistanceType::Cosine { + DistanceType::L2 + } else { + self.distance_type + }; + + self.ivf.find_partitions(&query.key, query.nprobes, dt) + } + + // async fn append(&self, batches: Vec) -> Result<()> { + // IvfIndexBuilder::new( + // dataset, + // column, + // index_dir, + // distance_type, + // shuffler, + // ivf_params, + // sub_index_params, + // quantizer_params, + // ) + // } + + async fn search_in_partition( + &self, + partition_id: usize, + query: &Query, + pre_filter: Arc, + ) -> Result { + let part_entry = self.load_partition(partition_id, true).await?; + + let query = self.preprocess_query(partition_id, query)?; + let param = (&query).into(); + // pre_filter.wait_for_ready().await?; + part_entry + .index + .search(query.key, query.k, param, &part_entry.storage, pre_filter) + } + fn is_loadable(&self) -> bool { false } @@ -353,21 +430,44 @@ impl IvfModel { + self.ivf.clone() + } + + fn quantizer(&self) -> Quantizer { + self.storage.quantizer::().unwrap() + } + + /// the index type of this vector index. 
+ fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + (S::name().try_into().unwrap(), Q::quantization_type()) + } + fn metric_type(&self) -> DistanceType { self.distance_type } } +pub type IvfFlatIndex = IVFIndex; +pub type IvfHnswSqIndex = IVFIndex; + #[cfg(test)] mod tests { + use std::collections::HashSet; use std::{collections::HashMap, ops::Range, sync::Arc}; + use arrow::datatypes::UInt64Type; use arrow::{array::AsArray, datatypes::Float32Type}; use arrow_array::{Array, FixedSizeListArray, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use lance_arrow::FixedSizeListArrayExt; - use lance_index::DatasetIndexExt; + use lance_core::ROW_ID; + use lance_index::vector::hnsw::builder::HnswBuildParams; + use lance_index::vector::ivf::IvfBuildParams; + use lance_index::vector::sq::builder::SQBuildParams; + use lance_index::vector::DIST_COL; + use lance_index::{DatasetIndexExt, IndexType}; use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_range; use tempfile::tempdir; @@ -395,7 +495,9 @@ mod tests { )]) .with_metadata(metadata) .into(); - let array = Arc::new(FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap()); + let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); + let fsl = lance_linalg::kernels::normalize_fsl(&fsl).unwrap(); + let array = Arc::new(fsl); let batch = RecordBatch::try_new(schema.clone(), vec![array.clone()]).unwrap(); let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone()); @@ -427,7 +529,7 @@ mod tests { async fn test_build_ivf_flat() { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); - let (mut dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; + let (mut dataset, vectors) = generate_test_dataset(test_uri, 0.0..1.0).await; let nlist = 16; let params = VectorIndexParams::ivf_flat(nlist, DistanceType::L2); @@ -442,52 +544,123 @@ mod 
tests { .await .unwrap(); - // TODO: test query after we replace the IVFIndex with the new one - // let query = vectors.value(0); - // let k = 100; - // let result = dataset - // .scan() - // .nearest("vector", query.as_primitive::(), k) - // .unwrap() - // .nprobs(nlist) - // .with_row_id() - // .try_into_batch() - // .await - // .unwrap(); - - // let row_ids = result - // .column_by_name(ROW_ID) - // .unwrap() - // .as_primitive::() - // .values() - // .to_vec(); - // let dists = result - // .column_by_name("_distance") - // .unwrap() - // .as_primitive::() - // .values() - // .to_vec(); - // let results = dists - // .into_iter() - // .zip(row_ids.into_iter()) - // .collect::>(); - // let row_ids = results.iter().map(|(_, id)| *id).collect::>(); - - // let gt = ground_truth( - // &vectors, - // query.as_primitive::().values(), - // k, - // DistanceType::L2, - // ); - // let gt_set = gt.iter().map(|r| r.1).collect::>(); - - // let recall = row_ids.intersection(>_set).count() as f32 / k as f32; - // assert!( - // recall >= 1.0, - // "recall: {}\n results: {:?}\n\ngt: {:?}", - // recall, - // results, - // gt, - // ); + let query = vectors.value(0); + let k = 100; + let result = dataset + .scan() + .nearest("vector", query.as_primitive::(), k) + .unwrap() + .nprobs(nlist) + .with_row_id() + .try_into_batch() + .await + .unwrap(); + + let row_ids = result[ROW_ID] + .as_primitive::() + .values() + .to_vec(); + let dists = result[DIST_COL] + .as_primitive::() + .values() + .to_vec(); + let results = dists + .into_iter() + .zip(row_ids.into_iter()) + .collect::>(); + let row_ids = results.iter().map(|(_, id)| *id).collect::>(); + + let gt = ground_truth( + &vectors, + query.as_primitive::().values(), + k, + DistanceType::L2, + ); + let gt_set = gt.iter().map(|r| r.1).collect::>(); + + let recall = row_ids.intersection(>_set).count() as f32 / k as f32; + assert!( + recall >= 1.0, + "recall: {}\n results: {:?}\n\ngt: {:?}", + recall, + results, + gt, + ); + } + + async fn 
test_create_ivf_hnsw_sq(distance_type: DistanceType) { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + + let nlist = 4; + let (mut dataset, vectors) = generate_test_dataset(test_uri, 0.0..1.0).await; + + let ivf_params = IvfBuildParams::new(nlist); + let sq_params = SQBuildParams::default(); + let hnsw_params = HnswBuildParams::default(); + let params = VectorIndexParams::with_ivf_hnsw_sq_params( + distance_type, + ivf_params, + hnsw_params, + sq_params, + ); + + dataset + .create_index(&["vector"], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + let query = vectors.value(0); + let k = 100; + let result = dataset + .scan() + .nearest("vector", query.as_primitive::(), k) + .unwrap() + .nprobs(nlist) + .with_row_id() + .try_into_batch() + .await + .unwrap(); + + let row_ids = result[ROW_ID] + .as_primitive::() + .values() + .to_vec(); + let dists = result[DIST_COL] + .as_primitive::() + .values() + .to_vec(); + let results = dists + .into_iter() + .zip(row_ids.into_iter()) + .collect::>(); + let row_ids = results.iter().map(|(_, id)| *id).collect::>(); + + let gt = ground_truth( + &vectors, + query.as_primitive::().values(), + k, + distance_type, + ); + let gt_set = gt.iter().map(|r| r.1).collect::>(); + + let recall = row_ids.intersection(>_set).count() as f32 / k as f32; + assert!( + recall >= 0.9, + "recall: {}\n results: {:?}\n\ngt: {:?}", + recall, + results, + gt, + ); + } + + #[tokio::test] + async fn test_create_ivf_hnsw_sq_cosine() { + test_create_ivf_hnsw_sq(DistanceType::Cosine).await + } + + #[tokio::test] + async fn test_create_ivf_hnsw_sq_dot() { + test_create_ivf_hnsw_sq(DistanceType::Dot).await } } diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 2980a522be..81ece189f0 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -6,6 +6,7 @@ use std::{any::Any, collections::HashMap}; use arrow::compute::concat; use 
arrow_array::types::{Float16Type, Float32Type, Float64Type}; +use arrow_array::UInt32Array; use arrow_array::{ cast::{as_primitive_array, AsArray}, Array, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array, @@ -18,8 +19,10 @@ use deepsize::DeepSizeOf; use lance_core::utils::tokio::spawn_cpu; use lance_core::ROW_ID; use lance_core::{utils::address::RowAddress, ROW_ID_FIELD}; +use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::storage::ProductQuantizationStorage; -use lance_index::vector::quantizer::Quantization; +use lance_index::vector::quantizer::{Quantization, QuantizationType, Quantizer}; +use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{ vector::{pq::ProductQuantizer, Query, DIST_COL}, Index, IndexType, @@ -36,7 +39,6 @@ use tracing::{instrument, span, Level}; pub use lance_index::vector::pq::{PQBuildParams, ProductQuantizerImpl}; use lance_linalg::kernels::normalize_fsl; -use super::ivf::Ivf; use super::VectorIndex; use crate::index::prefilter::PreFilter; use crate::index::vector::utils::maybe_sample_training_data; @@ -132,6 +134,10 @@ impl Index for PQIndex { self } + fn as_vector_index(self: Arc) -> Result> { + Ok(self) + } + fn index_type(&self) -> IndexType { IndexType::Vector } @@ -211,6 +217,19 @@ impl VectorIndex for PQIndex { .await } + fn find_partitions(&self, _: &Query) -> Result { + unimplemented!("only for IVF") + } + + async fn search_in_partition( + &self, + _: usize, + _: &Query, + _: Arc, + ) -> Result { + unimplemented!("only for IVF") + } + fn is_loadable(&self) -> bool { true } @@ -289,6 +308,18 @@ impl VectorIndex for PQIndex { Ok(()) } + fn ivf_model(&self) -> IvfModel { + unimplemented!("only for IVF") + } + fn quantizer(&self) -> Quantizer { + unimplemented!("only for IVF") + } + + /// the index type of this vector index. 
+ fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + (SubIndexType::Flat, QuantizationType::Product) + } + fn metric_type(&self) -> MetricType { self.metric_type } @@ -309,7 +340,7 @@ pub(super) async fn build_pq_model( dim: usize, metric_type: MetricType, params: &PQBuildParams, - ivf: Option<&Ivf>, + ivf: Option<&IvfModel>, ) -> Result> { if let Some(codebook) = ¶ms.codebook { let mt = if metric_type == MetricType::Cosine { @@ -381,7 +412,11 @@ pub(super) async fn build_pq_model( // Compute residual for PQ training. // // TODO: consolidate IVF models to `lance_index`. - let ivf2 = lance_index::vector::ivf::new_ivf(ivf.centroids.clone(), MetricType::L2, vec![]); + let ivf2 = lance_index::vector::ivf::new_ivf_transformer( + ivf.centroids.clone().unwrap(), + MetricType::L2, + vec![], + ); span!(Level::INFO, "compute residual for PQ training") .in_scope(|| ivf2.compute_residual(&training_data))? } else { @@ -470,7 +505,7 @@ mod tests { let centroids = generate_random_array_with_range(4 * DIM, -1.0..1.0); let fsl = FixedSizeListArray::try_new_from_values(centroids, DIM as i32).unwrap(); - let ivf = Ivf::new(fsl); + let ivf = IvfModel::new(fsl); let params = PQBuildParams::new(16, 8); let pq = build_pq_model(&dataset, "vector", DIM, MetricType::L2, ¶ms, Some(&ivf)) .await @@ -533,7 +568,11 @@ mod tests { let vectors = normalize_fsl(&vectors).unwrap(); let row = vectors.slice(0, 1); - let ivf2 = lance_index::vector::ivf::new_ivf(ivf.centroids.clone(), MetricType::L2, vec![]); + let ivf2 = lance_index::vector::ivf::new_ivf_transformer( + ivf.centroids.clone().unwrap(), + MetricType::L2, + vec![], + ); let residual_query = ivf2.compute_residual(&row).unwrap(); let pq_code = pq.transform(&residual_query).unwrap(); diff --git a/rust/lance/src/index/vector/sq.rs b/rust/lance/src/index/vector/sq.rs deleted file mode 100644 index 0542255f20..0000000000 --- a/rust/lance/src/index/vector/sq.rs +++ /dev/null @@ -1,75 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 
-// SPDX-FileCopyrightText: Copyright The Lance Authors - -use std::sync::Arc; - -use arrow::datatypes::Float32Type; -use arrow_array::{Array, RecordBatch}; -use lance_core::{Result, ROW_ID}; -use lance_index::vector::{ - quantizer::Quantization, - sq::{builder::SQBuildParams, storage::ScalarQuantizationStorage, ScalarQuantizer}, -}; -use lance_linalg::{distance::MetricType, kernels::normalize_fsl}; - -use crate::{index::vector::utils::maybe_sample_training_data, Dataset}; - -pub(super) async fn build_sq_model( - dataset: &Dataset, - column: &str, - metric_type: MetricType, - params: &SQBuildParams, -) -> Result { - log::info!("Start to train SQ code: SQ{}", params.num_bits); - let expected_sample_size = 2usize.pow(params.num_bits as u32) * params.sample_rate; - log::info!( - "Loading training data for SQ. Sample size: {}", - expected_sample_size - ); - let start = std::time::Instant::now(); - let mut training_data = - maybe_sample_training_data(dataset, column, expected_sample_size).await?; - log::info!( - "Finished loading training data in {:02} seconds", - start.elapsed().as_secs_f32() - ); - - log::info!( - "starting to compute partitions for SQ training, sample size: {}", - training_data.value_length() - ); - - if metric_type == MetricType::Cosine { - log::info!("Normalize training data for SQ training: Cosine"); - training_data = normalize_fsl(&training_data)?; - } - - log::info!("Start train SQ: params={:#?}", params); - let sq = params.build(&training_data, MetricType::L2)?; - log::info!( - "Trained SQ{}[{:?}] in: {} seconds", - sq.num_bits(), - sq.bounds(), - start.elapsed().as_secs_f32() - ); - Ok(sq) -} - -pub fn build_sq_storage( - metric_type: MetricType, - row_ids: Arc, - vectors: Arc, - sq: ScalarQuantizer, -) -> Result { - let code_column = sq.transform::(vectors.as_ref())?; - std::mem::drop(vectors); - - let batch = RecordBatch::try_from_iter_with_nullable(vec![ - (ROW_ID, row_ids, true), - (sq.column(), code_column, false), - ])?; - let store = - 
ScalarQuantizationStorage::try_new(sq.num_bits(), metric_type, sq.bounds(), [batch])?; - - Ok(store) -} diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 76afb79bee..797f4b71c5 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -28,9 +28,7 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use itertools::Itertools; use lance_core::utils::mask::{RowIdMask, RowIdTreeMap}; use lance_core::{ROW_ID, ROW_ID_FIELD}; -use lance_index::vector::{ - flat::flat_search, Query, VectorIndex, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN, -}; +use lance_index::vector::{flat::flat_search, Query, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN}; use lance_io::stream::RecordBatchStream; use lance_linalg::distance::DistanceType; use lance_linalg::kernels::normalize_arrow; @@ -44,7 +42,6 @@ use tracing::{instrument, Instrument}; use crate::dataset::scanner::DatasetRecordBatchStream; use crate::dataset::Dataset; use crate::index::prefilter::{DatasetPreFilter, FilterLoader}; -use crate::index::vector::ivf::IVFIndex; use crate::index::DatasetIndexInternalExt; use crate::{Error, Result}; @@ -487,12 +484,7 @@ impl ExecutionPlan for ANNIvfPartitionExec { let ds = ds.clone(); async move { - let raw_index = ds.open_vector_index(&query.column, &uuid).await?; - let index = raw_index.as_any().downcast_ref::().ok_or( - DataFusionError::Execution( - "ANNIVFPartitionExec: index is not a IVF type".to_string(), - ), - )?; + let index = ds.open_vector_index(&query.column, &uuid).await?; let mut query = query.clone(); if index.metric_type() == DistanceType::Cosine { @@ -734,13 +726,7 @@ impl ExecutionPlan for ANNIvfSubIndexExec { .map(move |result| { let query = query.clone(); async move { - let (part_id, (raw_index, pre_filter)) = result?; - - let index = raw_index.as_any().downcast_ref::().ok_or( - DataFusionError::Execution( - "ANNSubIndexExec: sub-index is not a IVF type".to_string(), - ), - )?; + let (part_id, 
(index, pre_filter)) = result?; let mut query = query.clone(); if index.metric_type() == DistanceType::Cosine { diff --git a/rust/lance/src/session/index_extension.rs b/rust/lance/src/session/index_extension.rs index 20d26cca2e..4148d399ba 100644 --- a/rust/lance/src/session/index_extension.rs +++ b/rust/lance/src/session/index_extension.rs @@ -65,10 +65,13 @@ mod test { sync::{atomic::AtomicBool, Arc}, }; - use arrow_array::RecordBatch; + use arrow_array::{RecordBatch, UInt32Array}; use arrow_schema::Schema; use deepsize::DeepSizeOf; use lance_file::writer::{FileWriter, FileWriterOptions}; + use lance_index::vector::ivf::storage::IvfModel; + use lance_index::vector::quantizer::{QuantizationType, Quantizer}; + use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{ vector::{hnsw::VECTOR_ID_FIELD, Query}, DatasetIndexExt, Index, IndexMetadata, IndexType, INDEX_FILE_NAME, @@ -100,6 +103,10 @@ mod test { self } + fn as_vector_index(self: Arc) -> Result> { + Ok(self) + } + fn statistics(&self) -> Result { Ok(json!(())) } @@ -116,7 +123,20 @@ mod test { #[async_trait::async_trait] impl VectorIndex for MockIndex { async fn search(&self, _: &Query, _: Arc) -> Result { - todo!("panic") + unimplemented!() + } + + fn find_partitions(&self, _: &Query) -> Result { + unimplemented!() + } + + async fn search_in_partition( + &self, + _: usize, + _: &Query, + _: Arc, + ) -> Result { + unimplemented!() } fn is_loadable(&self) -> bool { @@ -137,17 +157,29 @@ mod test { _: usize, _: usize, ) -> Result> { - todo!("panic") + unimplemented!() } fn row_ids(&self) -> Box> { - todo!("panic") + unimplemented!() } fn remap(&mut self, _: &HashMap>) -> Result<()> { Ok(()) } + fn ivf_model(&self) -> IvfModel { + unimplemented!() + } + fn quantizer(&self) -> Quantizer { + unimplemented!() + } + + /// the index type of this vector index. + fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + unimplemented!() + } + fn metric_type(&self) -> MetricType { MetricType::L2 }