From cbd0d8e4f9c0321f46e3faffee5fdba1f0fd0f2c Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Fri, 21 Jun 2024 10:30:14 -0700
Subject: [PATCH 01/13] chore: make flat storage and hnsw build accept dot
 distance type (#2499)

---
 python/src/utils.rs                         |  7 ++++++-
 rust/lance-index/src/vector/hnsw/builder.rs | 15 ++++++++++-----
 rust/lance/src/index/vector/ivf.rs          |  8 ++++----
 rust/lance/src/index/vector/ivf/io.rs       |  8 +++++---
 4 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/python/src/utils.rs b/python/src/utils.rs
index 9d6c0636ba..960870f0f7 100644
--- a/python/src/utils.rs
+++ b/python/src/utils.rs
@@ -153,12 +153,14 @@ impl Hnsw {
         max_level=7,
         m=20,
         ef_construction=100,
+        distance_type="l2",
     ))]
     fn build(
         vectors_array: &PyIterator,
         max_level: u16,
        m: usize,
        ef_construction: usize,
+        distance_type: &str,
    ) -> PyResult<Self> {
        let params = HnswBuildParams::default()
            .max_level(max_level)
@@ -177,9 +179,12 @@
        let vectors = concat(&array_refs).map_err(|e| PyIOError::new_err(e.to_string()))?;
        std::mem::drop(data);

+        let dt = DistanceType::try_from(distance_type)
+            .map_err(|e| PyValueError::new_err(e.to_string()))?;
+
        let hnsw = RT
            .runtime
-            .block_on(params.build(vectors.clone()))
+            .block_on(params.build(vectors.clone(), dt))
            .map_err(|e| PyIOError::new_err(e.to_string()))?;
        Ok(Self { hnsw, vectors })
    }
diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs
index 072efe6b0a..bd664abf4b 100644
--- a/rust/lance-index/src/vector/hnsw/builder.rs
+++ b/rust/lance-index/src/vector/hnsw/builder.rs
@@ -103,11 +103,15 @@ impl HnswBuildParams {
        self
    }

-    pub async fn build(self, data: ArrayRef) -> Result<HNSW> {
-        // We have normalized the vectors if the metric type is cosine, so we can use the L2 distance
+    /// Build the HNSW index from the given data.
+    ///
+    /// # Parameters
+    /// - `data`: the vectors to index, as a FixedSizeList array.
+    /// - `distance_type`: the distance type to use.
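+    ///
+    /// A minimal usage sketch (hypothetical `vectors: ArrayRef` holding a
+    /// FixedSizeList array; the builder setters mirror the ones used by the
+    /// Python binding above):
+    ///
+    /// ```ignore
+    /// let hnsw = HnswBuildParams::default()
+    ///     .max_level(7)
+    ///     .ef_construction(100)
+    ///     .build(vectors, DistanceType::Dot)
+    ///     .await?;
+    /// ```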
+    pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result<HNSW> {
        let vec_store = Arc::new(FlatStorage::new(
            data.as_fixed_size_list().clone(),
-            DistanceType::L2,
+            distance_type,
        ));
        HNSW::index_vectors(vec_store.as_ref(), self)
    }
@@ -648,11 +652,12 @@ impl IvfSubIndex for HNSW {
        };

        log::info!(
-            "Building HNSW graph: num={}, max_level={}, m={}, ef_construction={}",
+            "Building HNSW graph: num={}, max_level={}, m={}, ef_construction={}, distance_type={}",
            storage.len(),
            hnsw.inner.params.max_level,
            hnsw.inner.params.m,
-            hnsw.inner.params.ef_construction
+            hnsw.inner.params.ef_construction,
+            storage.distance_type(),
        );

        let len = storage.len();
diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs
index 573e6d3a2a..ffb032f1d4 100644
--- a/rust/lance/src/index/vector/ivf.rs
+++ b/rust/lance/src/index/vector/ivf.rs
@@ -2574,7 +2574,7 @@ mod tests {
        );
    }

-    async fn test_create_ivf_hnsw_sq(distance_type: DistanceType) {
+    async fn test_create_ivf_hnsw_sq(distance_type: DistanceType, expected_recall: f32) {
        let test_dir = tempdir().unwrap();
        let test_uri = test_dir.path().to_str().unwrap();

@@ -2641,7 +2641,7 @@
        let recall = results_set.intersection(&gt_set).count() as f32 / k as f32;
        assert!(
-            recall >= 0.9,
+            recall >= expected_recall,
            "recall: {}\n results: {:?}\n\ngt: {:?}",
            recall,
            results,
@@ -2651,12 +2651,12 @@

    #[tokio::test]
    async fn test_create_ivf_hnsw_sq_cosine() {
-        test_create_ivf_hnsw_sq(DistanceType::Cosine).await
+        test_create_ivf_hnsw_sq(DistanceType::Cosine, 0.9).await
    }

    #[tokio::test]
    async fn test_create_ivf_hnsw_sq_dot() {
-        test_create_ivf_hnsw_sq(DistanceType::Dot).await
+        test_create_ivf_hnsw_sq(DistanceType::Dot, 0.8).await
    }

    #[tokio::test]
diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs
index f65bcca6ef..74a0a17fb1 100644
--- a/rust/lance/src/index/vector/ivf/io.rs
+++ b/rust/lance/src/index/vector/ivf/io.rs
@@ -480,7 +480,8 @@ async fn build_hnsw_quantization_partition(
        metric_type = MetricType::L2;
    }

-    let build_hnsw = build_and_write_hnsw((*hnsw_params).clone(), vectors.clone(), writer);
+    let build_hnsw =
+        build_and_write_hnsw(vectors.clone(), (*hnsw_params).clone(), metric_type, writer);

    let build_store = match quantizer {
        Quantizer::Flat(_) => {
@@ -517,11 +518,12 @@
 }

 async fn build_and_write_hnsw(
-    params: HnswBuildParams,
    vectors: Arc<dyn Array>,
+    params: HnswBuildParams,
+    distance_type: DistanceType,
    mut writer: FileWriter<ManifestDescribing>,
 ) -> Result<usize> {
-    let batch = params.build(vectors).await?.to_batch()?;
+    let batch = params.build(vectors, distance_type).await?.to_batch()?;
    let metadata = batch.schema_ref().metadata().clone();
    writer.write_record_batch(batch).await?;
    writer.finish_with_metadata(&metadata).await

From 735f5b2720226d225c4845e3e031dc0a7e3ebe85 Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Fri, 21 Jun 2024 14:43:24 -0700
Subject: [PATCH 02/13] docs(python): note multiprocessing incompatibility
 (#2506)

Fixes #2405

---
 docs/integrations/pytorch.rst    | 5 +++++
 docs/integrations/tensorflow.rst | 5 ++++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/docs/integrations/pytorch.rst b/docs/integrations/pytorch.rst
index 0b49e2ac14..58d7285203 100644
--- a/docs/integrations/pytorch.rst
+++ b/docs/integrations/pytorch.rst
@@ -74,3 +74,8 @@ Available samplers:

 - :class:`lance.sampler.ShardedFragmentSampler`
 - :class:`lance.sampler.ShardedBatchSampler`
+
+.. warning::
+    When using multiprocessing, avoid the ``fork`` start method: Lance is
+    multi-threaded internally, and forking a multi-threaded process is prone
+    to deadlocks. Refer to `this discussion `_.
diff --git a/docs/integrations/tensorflow.rst b/docs/integrations/tensorflow.rst
index 37ed235316..c381abd363 100644
--- a/docs/integrations/tensorflow.rst
+++ b/docs/integrations/tensorflow.rst
@@ -88,4 +88,7 @@ workers.

    for batch in ds:
        print(batch)

-
+.. warning::
+    When using multiprocessing, avoid the ``fork`` start method: Lance is
+    multi-threaded internally, and forking a multi-threaded process is prone
+    to deadlocks. Refer to `this discussion `_.

From 95b67a9f0d5efaf1062f6ae577f1b6329aac2b3d Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Fri, 21 Jun 2024 16:18:24 -0700
Subject: [PATCH 03/13] ci: fix pytorch CI failure in arm64 jobs (#2507)

Fix the pytorch CI failure on arm64 by switching to the BuildJet runner service

---
 .github/workflows/python.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 743170c619..df6e712cfa 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -110,7 +110,8 @@ jobs:
        run: sudo rm -rf target/wheels
  linux-arm:
    timeout-minutes: 45
-    runs-on: warp-ubuntu-latest-arm64-4x
+    #runs-on: warp-ubuntu-latest-arm64-4x
+    runs-on: buildjet-4vcpu-ubuntu-2204-arm
    name: Python Linux 3.${{ matrix.python-minor-version }} ARM
    strategy:
      matrix:

From 0bf765b219a0a8c440b259921a3bf3506fb89a75 Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Sat, 22 Jun 2024 00:52:29 -0700
Subject: [PATCH 04/13] fix: dot distance so kmeans can converge (#2509)

BREAKING CHANGE: Fix the dot distance calculation from `-|xy|` to `1 - |xy|`.

* Add `log::info!` to show the loss when kmeans converges early.

---
 Cargo.toml                            | 32 +++++++++++++--------------
 python/Cargo.toml                     |  2 +-
 rust/lance-linalg/src/distance/dot.rs |  4 ++--
 rust/lance-linalg/src/kernels.rs      |  1 +
 rust/lance-linalg/src/kmeans.rs       | 10 ++++-----
 5 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 4274bec8d9..79efd71405 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,7 @@ exclude = ["python"]
 resolver = "2"

 [workspace.package]
-version = "0.12.4"
+version = "0.13.0"
 edition = "2021"
 authors = ["Lance Devs <dev@lancedb.com>"]
 license = "Apache-2.0"
@@ -40,23 +40,23 @@ categories = [
    "development-tools",
    "science",
 ]
-rust-version = "1.75"
+rust-version = "1.78"

 [workspace.dependencies]
-lance = { version = "=0.12.4", path = "./rust/lance" }
-lance-arrow = { version = "=0.12.4", path = "./rust/lance-arrow" }
-lance-core = { version = "=0.12.4", path = "./rust/lance-core" }
-lance-datafusion = { version = "=0.12.4", path = "./rust/lance-datafusion" }
-lance-datagen = { version = "=0.12.4", path = "./rust/lance-datagen" }
-lance-encoding = { version = "=0.12.4", path = "./rust/lance-encoding" }
-lance-encoding-datafusion = { version = "=0.12.4", path = "./rust/lance-encoding-datafusion" }
-lance-file = { version = "=0.12.4", path = "./rust/lance-file" }
-lance-index = { version = "=0.12.4", path = "./rust/lance-index" }
-lance-io = { version = "=0.12.4", path = "./rust/lance-io" }
-lance-linalg = { version = "=0.12.4", path = "./rust/lance-linalg" }
-lance-table = { version = "=0.12.4", path = "./rust/lance-table" }
-lance-test-macros = { version = "=0.12.4", path = "./rust/lance-test-macros" }
-lance-testing = { version = "=0.12.4", path = "./rust/lance-testing" }
+lance = { version = "=0.13.0", path = "./rust/lance" }
+lance-arrow = { version = "=0.13.0", path = "./rust/lance-arrow" }
"./rust/lance-arrow" } +lance-core = { version = "=0.13.0", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.13.0", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.13.0", path = "./rust/lance-datagen" } +lance-encoding = { version = "=0.13.0", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.13.0", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.13.0", path = "./rust/lance-file" } +lance-index = { version = "=0.13.0", path = "./rust/lance-index" } +lance-io = { version = "=0.13.0", path = "./rust/lance-io" } +lance-linalg = { version = "=0.13.0", path = "./rust/lance-linalg" } +lance-table = { version = "=0.13.0", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.13.0", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.13.0", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "51.0.0", optional = false, features = ["prettyprint"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index f22c570c2f..b35c7c4ea8 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.12.4" +version = "0.13.0" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" diff --git a/rust/lance-linalg/src/distance/dot.rs b/rust/lance-linalg/src/distance/dot.rs index c84d43426d..226bfdae46 100644 --- a/rust/lance-linalg/src/distance/dot.rs +++ b/rust/lance-linalg/src/distance/dot.rs @@ -66,10 +66,10 @@ pub fn dot(from: &[T], to: &[T]) -> f32 { T::dot(from, to) } -/// Negative dot distance. +/// Negative [Dot] distance. #[inline] pub fn dot_distance(from: &[T], to: &[T]) -> f32 { - -T::dot(from, to) + 1.0 - T::dot(from, to) } /// Dot product diff --git a/rust/lance-linalg/src/kernels.rs b/rust/lance-linalg/src/kernels.rs index 1e9d3c20e1..a9aa41f89a 100644 --- a/rust/lance-linalg/src/kernels.rs +++ b/rust/lance-linalg/src/kernels.rs @@ -75,6 +75,7 @@ pub fn argmin_value( /// Returns the minimal value (float) and the index (argmin) from an Iterator. /// /// Return `None` if the iterator is empty or all are `Nan/Inf`. 
+#[inline]
 pub fn argmin_value_float<T: Float>(iter: impl Iterator<Item = T>) -> Option<(u32, T)> {
    let mut min_idx = None;
    let mut min_value = T::infinity();
diff --git a/rust/lance-linalg/src/kmeans.rs b/rust/lance-linalg/src/kmeans.rs
index bb0e353638..85990dd3c6 100644
--- a/rust/lance-linalg/src/kmeans.rs
+++ b/rust/lance-linalg/src/kmeans.rs
@@ -30,7 +30,7 @@ use rayon::prelude::*;

 use crate::distance::hamming::hamming;
 use crate::distance::{dot_distance_batch, DistanceType};
-use crate::kernels::argmax;
+use crate::kernels::{argmax, argmin_value_float};
 use crate::{
    distance::{
        l2::{l2_distance_batch, L2},
@@ -214,11 +214,11 @@
    let cluster_and_dists = match distance_type {
        DistanceType::L2 => data
            .par_chunks(dimension)
-            .map(|vec| argmin_value(l2_distance_batch(vec, centroids, dimension)))
+            .map(|vec| argmin_value_float(l2_distance_batch(vec, centroids, dimension)))
            .collect::<Vec<_>>(),
        DistanceType::Dot => data
            .par_chunks(dimension)
-            .map(|vec| argmin_value(dot_distance_batch(vec, centroids, dimension)))
+            .map(|vec| argmin_value_float(dot_distance_batch(vec, centroids, dimension)))
            .collect::<Vec<_>>(),
        _ => {
            panic!(
@@ -500,8 +500,8 @@ impl KMeans {
        last_membership = Some(membership);
        if (loss - last_loss).abs() / last_loss < params.tolerance {
            info!(
-                "KMeans training: converged at iteration {} / {}, redo={}",
-                i, params.max_iters, redo
+                "KMeans training: converged at iteration {} / {}, redo={}, loss={}, last_loss={}, loss_diff={}",
+                i, params.max_iters, redo, loss, last_loss, (loss - last_loss).abs() / last_loss
            );
            break;
        }

From 25ea7fb29fb55492efa33edabf1bbebda40d426c Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Sun, 23 Jun 2024 13:59:48 -0700
Subject: [PATCH 05/13] perf: use faster kmeans find partition routing for pq
 assignment (#2515)

* Added a benchmark for PQ assignment that exercises the fast routing
* Use `compute_partition` for fast kmeans partition finding

---
 rust/lance-index/Cargo.toml               |  4 +++
 rust/lance-index/benches/pq_assignment.rs | 44 +++++++++++++++++++++++
 rust/lance-index/src/vector/pq.rs         | 11 +++---
 rust/lance-index/src/vector/pq/utils.rs   |  1 +
 rust/lance-linalg/src/kmeans.rs           | 36 ++++++++++++-------
 5 files changed, 77 insertions(+), 19 deletions(-)
 create mode 100644 rust/lance-index/benches/pq_assignment.rs

diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml
index 2f88fa684f..0991ae6e00 100644
--- a/rust/lance-index/Cargo.toml
+++ b/rust/lance-index/Cargo.toml
@@ -76,6 +76,10 @@ harness = false
 name = "pq_dist_table"
 harness = false

+[[bench]]
+name = "pq_assignment"
+harness = false
+
 [[bench]]
 name = "hnsw"
 harness = false
diff --git a/rust/lance-index/benches/pq_assignment.rs b/rust/lance-index/benches/pq_assignment.rs
new file mode 100644
index 0000000000..bf448633f7
--- /dev/null
+++ b/rust/lance-index/benches/pq_assignment.rs
@@ -0,0 +1,44 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! Benchmark of building PQ codes from dense vectors.
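+//!
+//! A usage note (an assumption based on the bench target registered in
+//! Cargo.toml above): run with `cargo bench --bench pq_assignment`.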
+
+use std::sync::Arc;
+
+use arrow_array::{types::Float32Type, FixedSizeListArray};
+use criterion::{criterion_group, criterion_main, Criterion};
+use lance_arrow::FixedSizeListArrayExt;
+use lance_index::vector::pq::{ProductQuantizer, ProductQuantizerImpl};
+use lance_linalg::distance::DistanceType;
+use lance_testing::datagen::generate_random_array_with_seed;
+
+const PQ: usize = 96;
+const DIM: usize = 1536;
+const TOTAL: usize = 32 * 1024;
+
+fn pq_transform(c: &mut Criterion) {
+    let codebook = Arc::new(generate_random_array_with_seed::<Float32Type>(
+        256 * DIM,
+        [88; 32],
+    ));
+
+    let vectors = generate_random_array_with_seed::<Float32Type>(DIM * TOTAL, [3; 32]);
+    let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
+
+    for dt in [DistanceType::L2, DistanceType::Dot].iter() {
+        let pq = ProductQuantizerImpl::<Float32Type>::new(PQ, 8, DIM, codebook.clone(), *dt);
+
+        c.bench_function(format!("{},{}", dt, TOTAL).as_str(), |b| {
+            b.iter(|| {
+                let _ = pq.transform(&fsl).unwrap();
+            })
+        });
+    }
+}
+
+criterion_group!(
+    name=benches;
+    config = Criterion::default().significance_level(0.1).sample_size(10);
+    targets = pq_transform);
+
+criterion_main!(benches);
diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs
index 37473aa0b7..9eee17c681 100644
--- a/rust/lance-index/src/vector/pq.rs
+++ b/rust/lance-index/src/vector/pq.rs
@@ -14,7 +14,7 @@ use lance_arrow::*;
 use lance_core::{Error, Result};
 use lance_linalg::distance::{dot_distance_batch, l2_distance_batch, DistanceType, Dot, L2};
 use lance_linalg::kernels::argmin_value_float;
-use lance_linalg::kmeans::kmeans_find_partitions;
+use lance_linalg::kmeans::compute_partition;
 use lance_linalg::{distance::MetricType, MatrixView};
 use rayon::prelude::*;
 use snafu::{location, Location};
@@ -360,14 +360,15 @@
            location: location!(),
        })?;

+    let sub_dim = dim / num_sub_vectors;
    let values = flatten_data
        .as_slice()
        .par_chunks(dim)
        .map(|vector| {
            vector
-                .chunks_exact(dim / num_sub_vectors)
+                .chunks_exact(sub_dim)
                .enumerate()
-                .flat_map(|(sub_idx, sub_vector)| {
+                .map(|(sub_idx, sub_vector)| {
                    let centroids = get_sub_vector_centroids(
                        codebook.as_slice(),
                        dim,
@@ -375,9 +376,7 @@
                        num_sub_vectors,
                        sub_idx,
                    );
-                    let parts = kmeans_find_partitions(centroids, sub_vector, 1, distance_type)
-                        .expect("kmeans_find_partitions failed");
-                    parts.values().iter().map(|v| *v as u8).collect::<Vec<_>>()
+                    compute_partition(centroids, sub_vector, distance_type).map(|v| v as u8)
                })
                .collect::<Vec<_>>()
        })
diff --git a/rust/lance-index/src/vector/pq/utils.rs b/rust/lance-index/src/vector/pq/utils.rs
index 08161947ce..80d5ff34d8 100644
--- a/rust/lance-index/src/vector/pq/utils.rs
+++ b/rust/lance-index/src/vector/pq/utils.rs
@@ -54,6 +54,7 @@ pub fn num_centroids(num_bits: impl Into<u32>) -> usize {
    2_usize.pow(num_bits.into())
 }

+#[inline]
 pub fn get_sub_vector_centroids<T>(
    codebook: &[T],
    dimension: usize,
diff --git a/rust/lance-linalg/src/kmeans.rs b/rust/lance-linalg/src/kmeans.rs
index 85990dd3c6..80f3735b04 100644
--- a/rust/lance-linalg/src/kmeans.rs
+++ b/rust/lance-linalg/src/kmeans.rs
@@ -685,22 +685,32 @@ pub fn compute_partitions(
    let dimension = dimension.as_();
    vectors
        .par_chunks(dimension)
-        .map(|vec| {
-            argmin_value(match distance_type {
-                DistanceType::L2 => l2_distance_batch(vec, centroids, dimension),
-                DistanceType::Dot => dot_distance_batch(vec, centroids, dimension),
-                _ => {
-                    panic!(
-                        "KMeans::find_partitions: {} is not supported",
-                        distance_type
-                    );
-                }
-            })
-            .map(|(idx, _)| idx)
-        })
+        .map(|vec| 
compute_partition(centroids, vec, distance_type))
        .collect::<Vec<_>>()
 }

+#[inline]
+pub fn compute_partition<T: Float + L2 + Dot>(
+    centroids: &[T],
+    vector: &[T],
+    distance_type: DistanceType,
+) -> Option<u32> {
+    match distance_type {
+        DistanceType::L2 => {
+            argmin_value_float(l2_distance_batch(vector, centroids, vector.len())).map(|c| c.0)
+        }
+        DistanceType::Dot => {
+            argmin_value_float(dot_distance_batch(vector, centroids, vector.len())).map(|c| c.0)
+        }
+        _ => {
+            panic!(
+                "KMeans::compute_partition: distance type {} is not supported",
+                distance_type
+            );
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
    use std::iter::repeat;

From 7cd11b8bd1e3eaf78ce40c26cccdbdcc21275ee1 Mon Sep 17 00:00:00 2001
From: LuQQiu
Date: Sun, 23 Jun 2024 16:59:23 -0700
Subject: [PATCH 06/13] chore(java): enable separate java builds (#2457)

The Java build, with the full Java 8 / 11 / 17 compile and tests, takes a
long time. This PR separates the builds so that they can run in parallel.

---------

Co-authored-by: Lei Xu
---
 .github/workflows/java.yml                   | 105 +++++++++---------
 java/core/lance-jni/src/fragment.rs          |   7 +-
 .../main/java/com/lancedb/lance/Dataset.java |   2 +-
 3 files changed, 58 insertions(+), 56 deletions(-)

diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index a7a533fd8a..5c245ed313 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -1,4 +1,5 @@
 name: Build and Run Java JNI Tests
+
 on:
  push:
    branches:
@@ -8,6 +9,7 @@
      - java/**
      - rust/**
      - .github/workflows/java.yml
+
 env:
  # This env var is used by Swatinem/rust-cache@v2 for the cache
  # key, so we set it to make sure it is always consistent.
@@ -20,74 +22,75 @@
  # CI builds are faster with incremental disabled.
  CARGO_INCREMENTAL: "0"
  CARGO_BUILD_JOBS: "1"
+
 jobs:
-  linux-build:
+  rust-clippy-fmt:
    runs-on: ubuntu-22.04
-    name: ubuntu-22.04 + Java 11 & 17
+    name: Rust Clippy and Fmt Check
    defaults:
      run:
-        working-directory: ./java
+        working-directory: ./java/core/lance-jni
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - uses: Swatinem/rust-cache@v2
        with:
-          workspaces: java/java-jni
+          workspaces: java/core/lance-jni
+      - name: Install dependencies
+        run: |
+          sudo apt update
+          sudo apt install -y protobuf-compiler libssl-dev
      - name: Run cargo fmt
        run: cargo fmt --check
-        working-directory: ./java/core/lance-jni
+      - name: Rust Clippy
+        run: cargo clippy --all-targets -- -D warnings
+
+  build-and-test-java:
+    runs-on: ubuntu-22.04
+    strategy:
+      matrix:
+        java-version: [8, 11, 17]
+    name: Build and Test with Java ${{ matrix.java-version }}
+    defaults:
+      run:
+        working-directory: ./java
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+      - uses: Swatinem/rust-cache@v2
+        with:
+          workspaces: java/core/lance-jni
      - name: Install dependencies
        run: |
          sudo apt update
          sudo apt install -y protobuf-compiler libssl-dev
-      - name: Install Java 17
-        uses: actions/setup-java@v4
-        with:
-          distribution: temurin
-          java-version: 17
-          cache: "maven"
-      - run: echo "JAVA_17=$JAVA_HOME" >> $GITHUB_ENV
-      - name: Install Java 8
+      - name: Set up Java ${{ matrix.java-version }}
        uses: actions/setup-java@v4
        with:
          distribution: temurin
-          java-version: 8
+          java-version: ${{ matrix.java-version }}
          cache: "maven"
-      - run: echo "JAVA_8=$JAVA_HOME" >> $GITHUB_ENV
-      - name: Install Java 11
-        uses: actions/setup-java@v4
-        with:
-          distribution: temurin
-          java-version: 11
-          cache: "maven"
-      - name: Java Style Check
-        run: mvn checkstyle:check
-      - name: Rust Clippy
-        working-directory: java/core/lance-jni
-        run: cargo clippy --all-targets -- -D warnings
-      - 
name: Running tests with Java 11 - run: mvn clean test - - name: Running tests with Java 8 - run: JAVA_HOME=$JAVA_8 mvn clean test - - name: Running tests with Java 17 + - name: Running tests with Java ${{ matrix.java-version }} run: | - export JAVA_TOOL_OPTIONS="$JAVA_TOOL_OPTIONS \ - -XX:+IgnoreUnrecognizedVMOptions \ - --add-opens=java.base/java.lang=ALL-UNNAMED \ - --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \ - --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \ - --add-opens=java.base/java.io=ALL-UNNAMED \ - --add-opens=java.base/java.net=ALL-UNNAMED \ - --add-opens=java.base/java.nio=ALL-UNNAMED \ - --add-opens=java.base/java.util=ALL-UNNAMED \ - --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \ - --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \ - --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED \ - --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \ - --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \ - --add-opens=java.base/sun.security.action=ALL-UNNAMED \ - --add-opens=java.base/sun.util.calendar=ALL-UNNAMED \ - --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED \ - -Djdk.reflect.useDirectMethodHandle=false \ - -Dio.netty.tryReflectionSetAccessible=true" - JAVA_HOME=$JAVA_17 mvn clean test + if [ "${{ matrix.java-version }}" == "17" ]; then + export JAVA_TOOL_OPTIONS="$JAVA_TOOL_OPTIONS \ + -XX:+IgnoreUnrecognizedVMOptions \ + --add-opens=java.base/java.lang=ALL-UNNAMED \ + --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \ + --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \ + --add-opens=java.base/java.io=ALL-UNNAMED \ + --add-opens=java.base/java.net=ALL-UNNAMED \ + --add-opens=java.base/java.nio=ALL-UNNAMED \ + --add-opens=java.base/java.util=ALL-UNNAMED \ + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \ + --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \ + --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED \ + --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \ + --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \ + --add-opens=java.base/sun.security.action=ALL-UNNAMED \ + --add-opens=java.base/sun.util.calendar=ALL-UNNAMED \ + --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED \ + -Djdk.reflect.useDirectMethodHandle=false \ + -Dio.netty.tryReflectionSetAccessible=true" + fi + mvn clean test diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs index 24d9863b22..2265822c11 100644 --- a/java/core/lance-jni/src/fragment.rs +++ b/java/core/lance-jni/src/fragment.rs @@ -34,10 +34,6 @@ use crate::{ RT, }; -/////////////////// -// Write Methods // -/////////////////// - ////////////////// // Read Methods // ////////////////// @@ -70,6 +66,9 @@ fn inner_count_rows_native( Ok(res) } +/////////////////// +// Write Methods // +/////////////////// #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local>( mut env: JNIEnv<'local>, diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index 2e7df86812..7228684264 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -213,7 +213,7 @@ public long latestVersion() { /** * Count the number of rows in the dataset. * - * @return num of rows. 
+ * @return num of rows */ public int countRows() { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { From e4122a695b1691d50364de4f211ff97fe5caf365 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Mon, 24 Jun 2024 11:04:46 -0700 Subject: [PATCH 07/13] fix: allow row id only in v2 (#2481) Closes #2480 --- rust/lance/src/dataset/fragment.rs | 82 +++++++++++++----------------- rust/lance/src/dataset/optimize.rs | 38 +++++++++++--- rust/lance/src/dataset/scanner.rs | 36 +++++++++++++ 3 files changed, 103 insertions(+), 53 deletions(-) diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index a3930a42e9..7bb2192b97 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -442,7 +442,7 @@ impl FileFragment { with_row_id: bool, with_row_address: bool, ) -> Result { - let open_files = self.open_readers(projection, with_row_id, with_row_address); + let open_files = self.open_readers(projection); let deletion_vec_load = self.load_deletion_vector(&self.dataset.object_store, &self.metadata); @@ -460,10 +460,10 @@ impl FileFragment { let deletion_vec = deletion_vec?; let row_id_sequence = row_id_sequence?; - if opened_files.is_empty() { + if opened_files.is_empty() && !with_row_id && !with_row_address { return Err(Error::io( format!( - "Does not find any data file for schema: {}\nfragment_id={}", + "Did not find any data files for schema: {}\nfragment_id={}", projection, self.id() ), @@ -471,6 +471,8 @@ impl FileFragment { )); } + let num_physical_rows = self.physical_rows().await?; + let mut reader = FragmentReader::try_new( self.id(), deletion_vec, @@ -478,6 +480,7 @@ impl FileFragment { opened_files, ArrowSchema::from(projection), self.count_rows().await?, + num_physical_rows, )?; if with_row_id { @@ -498,8 +501,6 @@ impl FileFragment { &self, data_file: &DataFile, projection: Option<&Schema>, - with_row_id: bool, - with_row_address: bool, ) -> Result, Arc)>> { let full_schema = self.dataset.schema(); // The data file may contain fields that are not part of the dataset any longer, remove those @@ -510,7 +511,7 @@ impl FileFragment { if data_file.is_legacy_file() { let max_field_id = data_file.fields.iter().max().unwrap(); - if with_row_id || with_row_address || !schema_per_file.fields.is_empty() { + if !schema_per_file.fields.is_empty() { let path = self.dataset.data_dir().child(data_file.path.as_str()); let field_id_offset = Self::get_field_id_offset(data_file); let reader = FileReader::try_new_with_fragment_id( @@ -566,17 +567,10 @@ impl FileFragment { async fn open_readers( &self, projection: &Schema, - with_row_id: bool, - with_row_address: bool, ) -> Result, Arc)>> { let mut opened_files = vec![]; - for (i, data_file) in self.metadata.files.iter().enumerate() { - // TODO: do we still need to do this? - let with_row_id = (with_row_id || with_row_address) && i == 0; - if let Some((reader, schema)) = self - .open_reader(data_file, Some(projection), with_row_id, with_row_address) - .await? - { + for data_file in &self.metadata.files { + if let Some((reader, schema)) = self.open_reader(data_file, Some(projection)).await? { opened_files.push((reader, schema)); } } @@ -668,16 +662,16 @@ impl FileFragment { // Just open any file. All of them should have same size. let some_file = &self.metadata.files[0]; - let (reader, _) = self - .open_reader(some_file, None, false, false) - .await? 
- .ok_or_else(|| Error::Internal { - message: format!( - "The data file {} did not have any fields contained in the dataset schema", - some_file.path - ), - location: location!(), - })?; + let (reader, _) = + self.open_reader(some_file, None) + .await? + .ok_or_else(|| Error::Internal { + message: format!( + "The data file {} did not have any fields contained in the dataset schema", + some_file.path + ), + location: location!(), + })?; Ok(reader.len() as usize) } @@ -762,16 +756,13 @@ impl FileFragment { } let get_lengths = self.metadata.files.iter().map(|data_file| async move { - let (reader, _) = self - .open_reader(data_file, None, false, false) - .await? - .ok_or_else(|| { - Error::corrupt_file( - self.dataset.data_dir().child(data_file.path.clone()), - "did not have any fields in common with the dataset schema", - location!(), - ) - })?; + let (reader, _) = self.open_reader(data_file, None).await?.ok_or_else(|| { + Error::corrupt_file( + self.dataset.data_dir().child(data_file.path.clone()), + "did not have any fields in common with the dataset schema", + location!(), + ) + })?; Result::Ok(reader.len() as usize) }); let get_lengths = try_join_all(get_lengths); @@ -1207,8 +1198,11 @@ pub struct FragmentReader { /// If false, deleted rows will be removed from the batch, requiring a copy make_deletions_null: bool, - // total number of rows in the fragment + // total number of real rows in the fragment (num_physical_rows - num_deleted_rows) num_rows: usize, + + // total number of physical rows in the fragment (all rows, ignoring deletions) + num_physical_rows: usize, } // Custom clone impl needed because it is not easy to clone Box @@ -1231,6 +1225,7 @@ impl Clone for FragmentReader { with_row_addr: self.with_row_addr, make_deletions_null: self.make_deletions_null, num_rows: self.num_rows, + num_physical_rows: self.num_physical_rows, } } } @@ -1264,15 +1259,9 @@ impl FragmentReader { readers: Vec<(Box, Arc)>, output_schema: ArrowSchema, num_rows: usize, + num_physical_rows: usize, ) -> Result { - if readers.is_empty() { - return Err(Error::io( - "Cannot create FragmentReader with zero readers".to_string(), - location!(), - )); - } - - if let Some(legacy_reader) = readers[0].0.as_legacy_opt() { + if let Some(legacy_reader) = readers.first().and_then(|reader| reader.0.as_legacy_opt()) { let num_batches = legacy_reader.num_batches(); for reader in readers.iter().skip(1) { if let Some(other_legacy) = reader.0.as_legacy_opt() { @@ -1301,6 +1290,7 @@ impl FragmentReader { with_row_addr: false, make_deletions_null: false, num_rows, + num_physical_rows, }) } @@ -1350,7 +1340,7 @@ impl FragmentReader { /// use streams, the updater still needs to know the batch size in v1 so that it can create /// files with the same batch size. pub(crate) fn legacy_num_rows_in_batch(&self, batch_id: u32) -> Option { - if let Some(legacy_reader) = self.readers[0].0.as_legacy_opt() { + if let Some(legacy_reader) = self.readers.first().and_then(|r| r.0.as_legacy_opt()) { if batch_id < legacy_reader.num_batches() as u32 { Some(legacy_reader.num_rows_in_batch(batch_id as i32) as u32) } else { @@ -1513,7 +1503,7 @@ impl FragmentReader { batch_size: u32, read_fn: impl Fn(&dyn GenericFileReader, &Arc) -> Result, ) -> Result { - let total_num_rows = self.readers[0].0.len(); + let total_num_rows = self.num_physical_rows as u32; // Note that the fragment length might be considerably smaller if there are deleted rows. // E.g. 
if a fragment has 100 rows but rows 0..10 are deleted we still need to make // sure it is valid to read / take 0..100 diff --git a/rust/lance/src/dataset/optimize.rs b/rust/lance/src/dataset/optimize.rs index ed09a7999b..c32c4fa906 100644 --- a/rust/lance/src/dataset/optimize.rs +++ b/rust/lance/src/dataset/optimize.rs @@ -916,6 +916,7 @@ mod tests { use arrow_array::{Float32Array, Int64Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use arrow_select::concat::concat_batches; + use rstest::rstest; use tempfile::tempdir; use super::*; @@ -1113,8 +1114,9 @@ mod tests { .unwrap() } + #[rstest] #[tokio::test] - async fn test_compact_empty() { + async fn test_compact_empty(#[values(false, true)] use_legacy_format: bool) { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); @@ -1122,7 +1124,16 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); let reader = RecordBatchIterator::new(vec![].into_iter().map(Ok), Arc::new(schema)); - let mut dataset = Dataset::write(reader, test_uri, None).await.unwrap(); + let mut dataset = Dataset::write( + reader, + test_uri, + Some(WriteParams { + use_legacy_format, + ..Default::default() + }), + ) + .await + .unwrap(); let plan = plan_compaction(&dataset, &CompactionOptions::default()) .await @@ -1137,8 +1148,9 @@ mod tests { assert_eq!(dataset.manifest.version, 1); } + #[rstest] #[tokio::test] - async fn test_compact_all_good() { + async fn test_compact_all_good(#[values(false, true)] use_legacy_format: bool) { // Compact a table with nothing to do let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); @@ -1148,6 +1160,7 @@ mod tests { // Just one file let write_params = WriteParams { max_rows_per_file: 10_000, + use_legacy_format, ..Default::default() }; let dataset = Dataset::write(reader, test_uri, Some(write_params)) @@ -1165,6 +1178,7 @@ mod tests { let write_params = WriteParams { max_rows_per_file: 3_000, max_rows_per_group: 1_000, + use_legacy_format, mode: WriteMode::Overwrite, ..Default::default() }; @@ -1217,8 +1231,9 @@ mod tests { } } + #[rstest] #[tokio::test] - async fn test_compact_many() { + async fn test_compact_many(#[values(false, true)] use_legacy_format: bool) { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); @@ -1228,6 +1243,7 @@ mod tests { let reader = RecordBatchIterator::new(vec![Ok(data.slice(0, 1200))], data.schema()); let write_params = WriteParams { max_rows_per_file: 400, + use_legacy_format, ..Default::default() }; Dataset::write(reader, test_uri, Some(write_params)) @@ -1238,6 +1254,7 @@ mod tests { let reader = RecordBatchIterator::new(vec![Ok(data.slice(1200, 2000))], data.schema()); let write_params = WriteParams { max_rows_per_file: 1000, + use_legacy_format, mode: WriteMode::Append, ..Default::default() }; @@ -1255,6 +1272,7 @@ mod tests { let reader = RecordBatchIterator::new(vec![Ok(data.slice(3200, 600))], data.schema()); let write_params = WriteParams { max_rows_per_file: 300, + use_legacy_format, mode: WriteMode::Append, ..Default::default() }; @@ -1358,8 +1376,9 @@ mod tests { assert_eq!(fragment_ids, vec![3, 7, 8, 9, 10]); } + #[rstest] #[tokio::test] - async fn test_compact_data_files() { + async fn test_compact_data_files(#[values(false, true)] use_legacy_format: bool) { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); @@ -1370,6 +1389,7 @@ mod tests { let write_params = WriteParams { max_rows_per_file: 
5_000,
            max_rows_per_group: 1_000,
+            use_legacy_format,
            ..Default::default()
        };
        let mut dataset = Dataset::write(reader, test_uri, Some(write_params))
@@ -1439,8 +1459,9 @@
        assert_eq!(scanned_data, data);
    }

+    #[rstest]
    #[tokio::test]
-    async fn test_compact_deletions() {
+    async fn test_compact_deletions(#[values(false, true)] use_legacy_format: bool) {
        // For files that have few rows, we don't want to compact just 1 since
        // that won't do anything. But if there are deletions to materialize,
        // we want to do groups of 1. This test checks that.
@@ -1453,6 +1474,7 @@
        let reader = RecordBatchIterator::new(vec![Ok(data.slice(0, 1000))], data.schema());
        let write_params = WriteParams {
            max_rows_per_file: 1000,
+            use_legacy_format,
            ..Default::default()
        };
        let mut dataset = Dataset::write(reader, test_uri, Some(write_params))
@@ -1490,8 +1512,9 @@
        assert!(fragments[0].metadata.deletion_file.is_none());
    }

+    #[rstest]
    #[tokio::test]
-    async fn test_compact_distributed() {
+    async fn test_compact_distributed(#[values(false, true)] use_legacy_format: bool) {
        // Can run the tasks independently
        // Can provide subset of tasks to commit_compaction
        // Once committed, can't commit remaining tasks
@@ -1504,6 +1527,7 @@
        let reader = RecordBatchIterator::new(vec![Ok(data.slice(0, 9000))], data.schema());
        let write_params = WriteParams {
            max_rows_per_file: 1000,
+            use_legacy_format,
            ..Default::default()
        };
        let mut dataset = Dataset::write(reader, test_uri, Some(write_params))
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index dcc2fe1e38..ff34d4d94b 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -2208,6 +2208,42 @@ mod test {
        assert_eq!(expected_i, actual_i);
    }

+    #[rstest]
+    #[tokio::test]
+    async fn test_only_row_id(#[values(false, true)] use_legacy_format: bool) {
+        let test_ds = TestVectorDataset::new(use_legacy_format).await.unwrap();
+        let dataset = &test_ds.dataset;
+
+        let mut scan = dataset.scan();
+        scan.project::<&str>(&[]).unwrap().with_row_id();
+
+        let results = scan
+            .try_into_stream()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        assert_eq!(results.len(), 1);
+        let batch = &results[0];
+
+        assert_eq!(batch.num_columns(), 1);
+        assert_eq!(batch.num_rows(), 400);
+        assert_eq!(
+            batch.schema().as_ref(),
+            &ArrowSchema::new(vec![ArrowField::new(ROW_ID, DataType::UInt64, true,)])
+        );
+
+        let expected_row_ids: Vec<u64> = (0..400).map(|i| i as u64).collect();
+        let actual_row_ids: Vec<u64> = as_primitive_array::<UInt64Type>(batch.column(0).as_ref())
+            .values()
+            .iter()
+            .copied()
+            .collect();
+        assert_eq!(expected_row_ids, actual_row_ids);
+    }
+
    #[tokio::test]
    async fn test_scan_unordered_with_row_id() {
        // This test doesn't make sense for v2 files, there is no way to get an out-of-order scan

From 1ee75d306f4668a874441d3518c90dc8032b16d4 Mon Sep 17 00:00:00 2001
From: Nick Darvey
Date: Tue, 25 Jun 2024 04:38:30 +1000
Subject: [PATCH 08/13] feat: enable aarch64-pc-windows builds (#2512)

Minimally enables `aarch64-pc-windows` by always returning false for
NEON FP16 support checks.
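This should unblock cross-compiles along the lines of
`cargo build --target aarch64-pc-windows-msvc` (the exact target triple and
toolchain setup are an assumption here, not something this patch configures).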
---
 rust/lance-core/src/utils/cpu.rs | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/rust/lance-core/src/utils/cpu.rs b/rust/lance-core/src/utils/cpu.rs
index b2e0ba31bc..4427922dd1 100644
--- a/rust/lance-core/src/utils/cpu.rs
+++ b/rust/lance-core/src/utils/cpu.rs
@@ -92,6 +92,14 @@ mod aarch64 {
    }
 }

+#[cfg(all(target_arch = "aarch64", target_os = "windows"))]
+mod aarch64 {
+    pub fn has_neon_f16_support() -> bool {
+        // https://github.com/lancedb/lance/issues/2411
+        false
+    }
+}
+
 #[cfg(target_arch = "loongarch64")]
 mod loongarch64 {
    pub fn has_lsx_support() -> bool {

From 95f98a90b5710df0e0492b89796979877a93c14b Mon Sep 17 00:00:00 2001
From: Raunak Shah
Date: Mon, 24 Jun 2024 14:05:50 -0700
Subject: [PATCH 09/13] refactor: combined capacity updates and decoding for
 all physical decoders (#2508)

- Combines `update_capacity` and `decode_into` into a single function
  `decode` for all physical (primitive) decoders (`BasicDecoder`,
  `BinaryPageDecoder`, `ValuePageDecoder`, `FixedListDecoder`,
  `BitmapDecoder`)
- As a result, some decoders no longer require explicit capacity allocation
  at a given level. This should give decoders more flexibility in their
  decode pipeline (required for dictionary encoding)
- Buffers are now built up from the leaves of the encoding tree and combined,
  instead of being recursively passed down (as before)
- Rename PhysicalPageDecoder to PrimitivePageDecoder

---
 rust/lance-encoding/src/decoder.rs                      | 52 +++++-------
 rust/lance-encoding/src/encodings/logical/primitive.rs  | 45 ++++------
 rust/lance-encoding/src/encodings/physical/basic.rs     | 69 ++++++--------
 rust/lance-encoding/src/encodings/physical/binary.rs    | 66 +++++--------
 rust/lance-encoding/src/encodings/physical/bitmap.rs    | 36 +++-----
 .../src/encodings/physical/fixed_size_list.rs           | 32 +++----
 rust/lance-encoding/src/encodings/physical/value.rs     | 31 +++----
 7 files changed, 117 insertions(+), 214 deletions(-)

diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs
index c684c89635..2f5d2c06cc 100644
--- a/rust/lance-encoding/src/decoder.rs
+++ b/rust/lance-encoding/src/decoder.rs
@@ -1144,50 +1144,44 @@ impl BatchDecodeStream {
 /// the decode task for batch 0 and the decode task for batch 1.
 ///
 /// See [`crate::decoder`] for more information
-pub trait PhysicalPageDecoder: Send + Sync {
-    /// Calculates and updates the capacity required to represent the requested data
+pub trait PrimitivePageDecoder: Send + Sync {
+    /// Decode data into buffers
+    ///
+    /// This may be a simple zero-copy from a disk buffer or could involve complex decoding
+    /// such as decompressing from some compressed representation.
    ///
-    /// Capacity is stored as a tuple of (num_bytes: u64, is_needed: bool). The `is_needed`
-    /// portion only needs to be updated if the encoding has some concept of an "optional"
-    /// buffer.
+    /// Decoded data is returned as buffers allocated by the decoder itself; callers
+    /// no longer pre-compute a (num_bytes, is_needed) capacity tuple per buffer.
    ///
+    /// Encodings can have any number of input or output buffers.
For example, a dictionary + /// decoding will convert two buffers (indices + dictionary) into a single buffer /// - /// # Arguments + /// Binary decodings have two output buffers (one for values, one for offsets) /// - /// * `rows_to_skip` - how many rows to skip (within the page) before decoding - /// * `num_rows` - how many rows to decode - /// * `buffers` - A mutable slice of "capacities" (as described above), one per buffer - /// * `all_null` - A mutable bool, set to true if a decoder determines all values are null - fn update_capacity( - &self, - rows_to_skip: u32, - num_rows: u32, - buffers: &mut [(u64, bool)], - all_null: &mut bool, - ); - /// Decodes the data into the requested buffers. + /// Other decodings could even expand the # of output buffers. For example, we could decode + /// fixed size strings into variable length strings going from one input buffer to multiple output + /// buffers. /// - /// You can assume that the capacity will have already been configured on the `BytesMut` - /// according to the capacity calculated in [`PhysicalPageDecoder::update_capacity`] + /// Each Arrow data type typically has a fixed structure of buffers and the encoding chain will + /// generally end at one of these structures. However, intermediate structures may exist which + /// do not correspond to any Arrow type at all. For example, a bitpacking encoding will deal + /// with buffers that have bits-per-value that is not a multiple of 8. /// + /// The `primitive_array_from_buffers` method has an expected buffer layout for each arrow + /// type (order matters) and encodings that aim to decode into arrow types should respect + /// this layout. /// # Arguments /// /// * `rows_to_skip` - how many rows to skip (within the page) before decoding /// * `num_rows` - how many rows to decode - /// * `dest_buffers` - the output buffers to decode into - fn decode_into( + /// * `all_null` - A mutable bool, set to true if a decoder determines all values are null + fn decode( &self, rows_to_skip: u32, num_rows: u32, - dest_buffers: &mut [BytesMut], - ) -> Result<()>; + all_null: &mut bool, + ) -> Result>; fn num_buffers(&self) -> u32; } @@ -1217,7 +1211,7 @@ pub trait PageScheduler: Send + Sync + std::fmt::Debug { ranges: &[Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>>; + ) -> BoxFuture<'static, Result>>; } /// Contains the context for a scheduler diff --git a/rust/lance-encoding/src/encodings/logical/primitive.rs b/rust/lance-encoding/src/encodings/logical/primitive.rs index 003a179943..1f1457f6f2 100644 --- a/rust/lance-encoding/src/encodings/logical/primitive.rs +++ b/rust/lance-encoding/src/encodings/logical/primitive.rs @@ -29,7 +29,7 @@ use lance_core::{Error, Result}; use crate::{ decoder::{ DecodeArrayTask, FieldScheduler, FilterExpression, LogicalPageDecoder, NextDecodeTask, - PageInfo, PageScheduler, PhysicalPageDecoder, ScheduledScanLine, SchedulerContext, + PageInfo, PageScheduler, PrimitivePageDecoder, ScheduledScanLine, SchedulerContext, SchedulingJob, }, encoder::{ArrayEncodingStrategy, EncodeTask, EncodedColumn, EncodedPage, FieldEncoder}, @@ -211,15 +211,15 @@ impl FieldScheduler for PrimitiveFieldScheduler { pub struct PrimitiveFieldDecoder { data_type: DataType, - unloaded_physical_decoder: Option>>>, - physical_decoder: Option>, + unloaded_physical_decoder: Option>>>, + physical_decoder: Option>, num_rows: u32, rows_drained: u32, } impl PrimitiveFieldDecoder { pub fn new_from_data( - physical_decoder: Arc, + physical_decoder: Arc, data_type: DataType, 
num_rows: u32, ) -> Self { @@ -246,46 +246,25 @@ impl Debug for PrimitiveFieldDecoder { struct PrimitiveFieldDecodeTask { rows_to_skip: u32, rows_to_take: u32, - physical_decoder: Arc, + physical_decoder: Arc, data_type: DataType, } impl DecodeArrayTask for PrimitiveFieldDecodeTask { fn decode(self: Box) -> Result { - // We start by assuming that no buffers are required. The number of buffers needed is based - // on the data type. Most data types need two buffers but each layer of fixed-size-list, for - // example, adds another validity buffer - let mut capacities = vec![(0, false); self.physical_decoder.num_buffers() as usize]; let mut all_null = false; - self.physical_decoder.update_capacity( - self.rows_to_skip, - self.rows_to_take, - &mut capacities, - &mut all_null, - ); + + // The number of buffers needed is based on the data type. + // Most data types need two buffers but each layer of fixed-size-list, for + // example, adds another validity buffer. + let bufs = + self.physical_decoder + .decode(self.rows_to_skip, self.rows_to_take, &mut all_null)?; if all_null { return Ok(new_null_array(&self.data_type, self.rows_to_take as usize)); } - // At this point we know the size needed for each buffer - let mut bufs = capacities - .into_iter() - .map(|(num_bytes, is_needed)| { - // Only allocate the validity buffer if it is needed, otherwise we - // create an empty BytesMut (does not require allocation) - if is_needed { - BytesMut::with_capacity(num_bytes as usize) - } else { - BytesMut::default() - } - }) - .collect::>(); - - // Go ahead and fill the validity / values buffers - self.physical_decoder - .decode_into(self.rows_to_skip, self.rows_to_take, &mut bufs)?; - // Convert the two buffers into an Arrow array Self::primitive_array_from_buffers(&self.data_type, bufs, self.rows_to_take) } diff --git a/rust/lance-encoding/src/encodings/physical/basic.rs b/rust/lance-encoding/src/encodings/physical/basic.rs index 15fcea00e0..87c466eada 100644 --- a/rust/lance-encoding/src/encodings/physical/basic.rs +++ b/rust/lance-encoding/src/encodings/physical/basic.rs @@ -5,11 +5,12 @@ use std::sync::Arc; use arrow_array::{ArrayRef, BooleanArray}; use arrow_buffer::BooleanBuffer; +use bytes::BytesMut; use futures::{future::BoxFuture, FutureExt}; use log::trace; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, encoder::{ArrayEncoder, BufferEncoder, EncodedArray, EncodedArrayBuffer}, format::pb, EncodingsIo, @@ -20,21 +21,21 @@ use lance_core::Result; use super::buffers::BitmapBufferEncoder; struct DataDecoders { - validity: Box, - values: Box, + validity: Box, + values: Box, } enum DataNullStatus { // Neither validity nor values All, // Values only - None(Box), + None(Box), // Validity and values Some(DataDecoders), } impl DataNullStatus { - fn values_decoder(&self) -> Option<&dyn PhysicalPageDecoder> { + fn values_decoder(&self) -> Option<&dyn PrimitivePageDecoder> { match self { Self::All => None, Self::Some(decoders) => Some(decoders.values.as_ref()), @@ -124,7 +125,7 @@ impl PageScheduler for BasicPageScheduler { ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>> { + ) -> BoxFuture<'static, Result>> { let validity_future = match &self.mode { SchedulerNullStatus::None(_) | SchedulerNullStatus::All => None, SchedulerNullStatus::Some(schedulers) => Some(schedulers.validity.schedule_ranges( @@ -157,7 +158,7 @@ impl PageScheduler for BasicPageScheduler { } _ => unreachable!(), }; - 
Ok(Box::new(BasicPageDecoder { mode }) as Box) + Ok(Box::new(BasicPageDecoder { mode }) as Box) } .boxed() } @@ -167,51 +168,39 @@ struct BasicPageDecoder { mode: DataNullStatus, } -impl PhysicalPageDecoder for BasicPageDecoder { - fn update_capacity( +impl PrimitivePageDecoder for BasicPageDecoder { + fn decode( &self, rows_to_skip: u32, num_rows: u32, - buffers: &mut [(u64, bool)], all_null: &mut bool, - ) { - // No need to look at the validity decoder to know the dest buffer size since it is boolean - buffers[0].0 = arrow_buffer::bit_util::ceil(num_rows as usize, 8) as u64; - // The validity buffer is only required if we have some nulls - buffers[0].1 = matches!(self.mode, DataNullStatus::Some(_)); - if let Some(values) = self.mode.values_decoder() { - values.update_capacity(rows_to_skip, num_rows, &mut buffers[1..], all_null); - } else { - *all_null = true; - } - } - - fn decode_into( - &self, - rows_to_skip: u32, - num_rows: u32, - dest_buffers: &mut [bytes::BytesMut], - ) -> Result<()> { - match &self.mode { + ) -> Result> { + let dest_buffers = match &self.mode { DataNullStatus::Some(decoders) => { - decoders - .validity - .decode_into(rows_to_skip, num_rows, &mut dest_buffers[..1])?; - decoders - .values - .decode_into(rows_to_skip, num_rows, &mut dest_buffers[1..])?; + let mut buffers = decoders.validity.decode(rows_to_skip, num_rows, all_null)?; // buffer 0 + let mut values_bytesmut = + decoders.values.decode(rows_to_skip, num_rows, all_null)?; // buffer 1 onwards + + buffers.append(&mut values_bytesmut); + buffers } // Either dest_buffers[0] is empty, in which case these are no-ops, or one of the // other pages needed the buffer, in which case we need to fill our section DataNullStatus::All => { - dest_buffers[0].fill(0); + let buffers = vec![BytesMut::default()]; + *all_null = true; + buffers } DataNullStatus::None(values) => { - dest_buffers[0].fill(1); - values.decode_into(rows_to_skip, num_rows, &mut dest_buffers[1..])?; + let mut dest_buffers = vec![BytesMut::default()]; + + let mut values_bytesmut = values.decode(rows_to_skip, num_rows, all_null)?; + dest_buffers.append(&mut values_bytesmut); + dest_buffers } - } - Ok(()) + }; + + Ok(dest_buffers) } fn num_buffers(&self) -> u32 { diff --git a/rust/lance-encoding/src/encodings/physical/binary.rs b/rust/lance-encoding/src/encodings/physical/binary.rs index 42937c30cd..6d2c47d704 100644 --- a/rust/lance-encoding/src/encodings/physical/binary.rs +++ b/rust/lance-encoding/src/encodings/physical/binary.rs @@ -11,12 +11,13 @@ use arrow_array::{ builder::{ArrayBuilder, Int32Builder, UInt32Builder, UInt8Builder}, Array, ArrayRef, Int32Array, UInt32Array, }; +use bytes::BytesMut; use futures::stream::StreamExt; use futures::{future::BoxFuture, stream::FuturesOrdered, FutureExt}; // use rand::seq::index; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, encoder::{ArrayEncoder, EncodedArray}, format::pb, EncodingsIo, @@ -54,7 +55,7 @@ impl PageScheduler for BinaryPageScheduler { ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>> { + ) -> BoxFuture<'static, Result>> { // ranges corresponds to row ranges that the user wants to fetch. 
// if user wants row range a..b // Case 1: if a != 0, we need indices a-1..b to decode @@ -101,7 +102,7 @@ impl PageScheduler for BinaryPageScheduler { let mut curr_range_idx = 0; let mut last = 0; while let Some(indices_page_decoder) = futures_ordered.next().await { - let indices: Arc = Arc::from(indices_page_decoder?); + let indices: Arc = Arc::from(indices_page_decoder?); // Build and run decode task for offsets let curr_indices_range = copy_indices_ranges[curr_range_idx].clone(); @@ -164,12 +165,12 @@ impl PageScheduler for BinaryPageScheduler { top_level_row, ); - let bytes_decoder: Box = bytes_page_decoder.await?; + let bytes_decoder: Box = bytes_page_decoder.await?; Ok(Box::new(BinaryPageDecoder { decoded_indices, bytes_decoder, - }) as Box) + }) as Box) } .boxed() } @@ -177,10 +178,10 @@ impl PageScheduler for BinaryPageScheduler { struct BinaryPageDecoder { decoded_indices: Arc, - bytes_decoder: Box, + bytes_decoder: Box, } -impl PhysicalPageDecoder for BinaryPageDecoder { +impl PrimitivePageDecoder for BinaryPageDecoder { // Continuing the example from BinaryPageScheduler // Suppose batch_size = 2. Then first, rows_to_skip=0, num_rows=2 // Need to scan 2 rows @@ -188,42 +189,16 @@ impl PhysicalPageDecoder for BinaryPageDecoder { // Allocate 8 bytes capacity. // Next rows_to_skip=2, num_rows=1 // Skip 8 bytes. Allocate 5 bytes capacity. - fn update_capacity( - &self, - rows_to_skip: u32, - num_rows: u32, - buffers: &mut [(u64, bool)], - all_null: &mut bool, - ) { - let offsets = self - .decoded_indices - .as_any() - .downcast_ref::() - .unwrap(); - - // 32 bits or 4 bytes per value. - buffers[0].0 = (num_rows as u64) * 4; - buffers[0].1 = true; - - let bytes_to_skip = offsets.value(rows_to_skip as usize); - let num_bytes = offsets.value((rows_to_skip + num_rows) as usize) - bytes_to_skip; - - self.bytes_decoder - .update_capacity(bytes_to_skip, num_bytes, &mut buffers[1..], all_null); - } - - // Continuing from update_capacity: - // When rows_to_skip=2, num_rows=1 // The normalized offsets are [0, 4, 8, 13] // We only need [8, 13] to decode in this case. // These need to be normalized in order to build the string later // So return [0, 5] - fn decode_into( + fn decode( &self, rows_to_skip: u32, num_rows: u32, - dest_buffers: &mut [bytes::BytesMut], - ) -> Result<()> { + all_null: &mut bool, + ) -> Result> { let offsets = self .decoded_indices .as_any() @@ -232,12 +207,15 @@ impl PhysicalPageDecoder for BinaryPageDecoder { let bytes_to_skip = offsets.value(rows_to_skip as usize); let num_bytes = offsets.value((rows_to_skip + num_rows) as usize) - bytes_to_skip; - let target_offsets = offsets.slice( rows_to_skip.try_into().unwrap(), (num_rows + 1).try_into().unwrap(), ); + let mut bytes_buffers = self + .bytes_decoder + .decode(bytes_to_skip, num_bytes, all_null)?; + // Normalize offsets let target_vec = target_offsets.values(); let normalized_array: PrimitiveArray = @@ -245,18 +223,10 @@ impl PhysicalPageDecoder for BinaryPageDecoder { let normalized_values = normalized_array.values(); let byte_slice = normalized_values.inner().deref(); + let mut dest_buffers = vec![BytesMut::from(byte_slice)]; + dest_buffers.append(&mut bytes_buffers); - // copy target_offsets into dest_buffers[0] - dest_buffers[0].extend_from_slice(byte_slice); - - // Copy decoded bytes into dest_buffers[1..] 
- // Currently an empty null buffer is the first one - // The actual bytes are in the second buffer - // Including the indices this results in 3 buffers in total - self.bytes_decoder - .decode_into(bytes_to_skip, num_bytes, &mut dest_buffers[1..])?; - - Ok(()) + Ok(dest_buffers) } fn num_buffers(&self) -> u32 { diff --git a/rust/lance-encoding/src/encodings/physical/bitmap.rs b/rust/lance-encoding/src/encodings/physical/bitmap.rs index b95be7580d..12fe355d88 100644 --- a/rust/lance-encoding/src/encodings/physical/bitmap.rs +++ b/rust/lance-encoding/src/encodings/physical/bitmap.rs @@ -11,7 +11,7 @@ use lance_core::Result; use log::trace; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, EncodingsIo, }; @@ -36,7 +36,7 @@ impl PageScheduler for DenseBitmapScheduler { ranges: &[Range], scheduler: &Arc, top_level_row: u64, - ) -> BoxFuture<'static, Result>> { + ) -> BoxFuture<'static, Result>> { let mut min = u64::MAX; let mut max = 0; let chunk_reqs = ranges @@ -76,7 +76,7 @@ impl PageScheduler for DenseBitmapScheduler { length, }) .collect::>(); - Ok(Box::new(BitmapDecoder { chunks }) as Box) + Ok(Box::new(BitmapDecoder { chunks }) as Box) } .boxed() } @@ -92,24 +92,16 @@ struct BitmapDecoder { chunks: Vec, } -impl PhysicalPageDecoder for BitmapDecoder { - fn update_capacity( +impl PrimitivePageDecoder for BitmapDecoder { + fn decode( &self, - _rows_to_skip: u32, + rows_to_skip: u32, num_rows: u32, - buffers: &mut [(u64, bool)], _all_null: &mut bool, - ) { - buffers[0].0 = arrow_buffer::bit_util::ceil(num_rows as usize, 8) as u64; - buffers[0].1 = true; - } + ) -> Result> { + let num_bytes = arrow_buffer::bit_util::ceil(num_rows as usize, 8); + let mut dest_buffers = vec![BytesMut::with_capacity(num_bytes)]; - fn decode_into( - &self, - rows_to_skip: u32, - num_rows: u32, - dest_buffers: &mut [BytesMut], - ) -> Result<()> { let mut rows_to_skip = rows_to_skip; let mut dest_builder = BooleanBufferBuilder::new(num_rows as usize); @@ -146,7 +138,7 @@ impl PhysicalPageDecoder for BitmapDecoder { // // It's a moot point at the moment since we don't support page bridging dest_buffers[0].copy_from_slice(bool_buffer.as_slice()); - Ok(()) + Ok(dest_buffers) } fn num_buffers(&self) -> u32 { @@ -157,9 +149,9 @@ impl PhysicalPageDecoder for BitmapDecoder { #[cfg(test)] mod tests { use arrow_schema::{DataType, Field}; - use bytes::{Bytes, BytesMut}; + use bytes::Bytes; - use crate::decoder::PhysicalPageDecoder; + use crate::decoder::PrimitivePageDecoder; use crate::encodings::physical::bitmap::BitmapData; use crate::testing::check_round_trip_encoding_random; @@ -189,8 +181,8 @@ mod tests { }, ], }; - let mut dest = vec![BytesMut::with_capacity(1)]; - let result = decoder.decode_into(5, 1, &mut dest); + + let result = decoder.decode(5, 1, &mut false); assert!(result.is_ok()); } } diff --git a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs index 0d69f95aab..ce6a591226 100644 --- a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs +++ b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs @@ -4,12 +4,13 @@ use std::sync::Arc; use arrow_array::{cast::AsArray, ArrayRef}; +use bytes::BytesMut; use futures::{future::BoxFuture, FutureExt}; use lance_core::Result; use log::trace; use crate::{ - decoder::{PageScheduler, PhysicalPageDecoder}, + decoder::{PageScheduler, PrimitivePageDecoder}, encoder::{ArrayEncoder, EncodedArray}, format::pb, EncodingsIo, 
diff --git a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs
index 0d69f95aab..ce6a591226 100644
--- a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs
+++ b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs
@@ -4,12 +4,13 @@ use std::sync::Arc;
 
 use arrow_array::{cast::AsArray, ArrayRef};
+use bytes::BytesMut;
 use futures::{future::BoxFuture, FutureExt};
 use lance_core::Result;
 use log::trace;
 
 use crate::{
-    decoder::{PageScheduler, PhysicalPageDecoder},
+    decoder::{PageScheduler, PrimitivePageDecoder},
     encoder::{ArrayEncoder, EncodedArray},
     format::pb,
     EncodingsIo,
@@ -39,7 +40,7 @@ impl PageScheduler for FixedListScheduler {
         ranges: &[std::ops::Range<u32>],
         scheduler: &Arc<dyn EncodingsIo>,
         top_level_row: u64,
-    ) -> BoxFuture<'static, Result<Box<dyn PhysicalPageDecoder>>> {
+    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
         let expanded_ranges = ranges
             .iter()
             .map(|range| (range.start * self.dimension)..(range.end * self.dimension))
@@ -61,42 +62,27 @@ impl PageScheduler for FixedListScheduler {
             Ok(Box::new(FixedListDecoder {
                 items_decoder,
                 dimension,
-            }) as Box<dyn PhysicalPageDecoder>)
+            }) as Box<dyn PrimitivePageDecoder>)
         }
         .boxed()
     }
 }
 
 pub struct FixedListDecoder {
-    items_decoder: Box<dyn PhysicalPageDecoder>,
+    items_decoder: Box<dyn PrimitivePageDecoder>,
     dimension: u32,
 }
 
-impl PhysicalPageDecoder for FixedListDecoder {
-    fn update_capacity(
+impl PrimitivePageDecoder for FixedListDecoder {
+    fn decode(
         &self,
         rows_to_skip: u32,
         num_rows: u32,
-        buffers: &mut [(u64, bool)],
         all_null: &mut bool,
-    ) {
+    ) -> Result<Vec<BytesMut>> {
         let rows_to_skip = rows_to_skip * self.dimension;
         let num_rows = num_rows * self.dimension;
-        self.items_decoder
-            .update_capacity(rows_to_skip, num_rows, buffers, all_null);
-    }
-
-    fn decode_into(
-        &self,
-        rows_to_skip: u32,
-        num_rows: u32,
-        dest_buffers: &mut [bytes::BytesMut],
-    ) -> Result<()> {
-        let rows_to_skip = rows_to_skip * self.dimension;
-        let num_rows = num_rows * self.dimension;
-        self.items_decoder
-            .decode_into(rows_to_skip, num_rows, dest_buffers)?;
-        Ok(())
+        self.items_decoder.decode(rows_to_skip, num_rows, all_null)
    }

    fn num_buffers(&self) -> u32 {
diff --git a/rust/lance-encoding/src/encodings/physical/value.rs b/rust/lance-encoding/src/encodings/physical/value.rs
index 3e8407d6fa..046438894c 100644
--- a/rust/lance-encoding/src/encodings/physical/value.rs
+++ b/rust/lance-encoding/src/encodings/physical/value.rs
@@ -3,7 +3,7 @@
 use arrow_array::ArrayRef;
 use arrow_schema::DataType;
-use bytes::Bytes;
+use bytes::{Bytes, BytesMut};
 use futures::{future::BoxFuture, FutureExt};
 use lance_arrow::DataTypeExt;
 use log::trace;
@@ -13,7 +13,7 @@
 use std::ops::Range;
 use std::sync::{Arc, Mutex};
 
 use crate::{
-    decoder::{PageScheduler, PhysicalPageDecoder},
+    decoder::{PageScheduler, PrimitivePageDecoder},
     encoder::{ArrayEncoder, BufferEncoder, EncodedArray, EncodedArrayBuffer},
     format::pb,
     EncodingsIo,
@@ -85,7 +85,7 @@ impl PageScheduler for ValuePageScheduler {
         ranges: &[std::ops::Range<u32>],
         scheduler: &Arc<dyn EncodingsIo>,
         top_level_row: u64,
-    ) -> BoxFuture<'static, Result<Box<dyn PhysicalPageDecoder>>> {
+    ) -> BoxFuture<'static, Result<Box<dyn PrimitivePageDecoder>>> {
         let (mut min, mut max) = (u64::MAX, 0);
         let byte_ranges = if self.compression_scheme == CompressionScheme::None {
             ranges
@@ -139,7 +139,7 @@ impl PageScheduler for ValuePageScheduler {
                 data: bytes,
                 uncompressed_data: Arc::new(Mutex::new(None)),
                 uncompressed_range_offsets: range_offsets,
-            }) as Box<dyn PhysicalPageDecoder>)
+            }) as Box<dyn PrimitivePageDecoder>)
         }
         .boxed()
     }
@@ -203,27 +203,20 @@ impl ValuePageDecoder {
     }
 }
 
-impl PhysicalPageDecoder for ValuePageDecoder {
-    fn update_capacity(
+impl PrimitivePageDecoder for ValuePageDecoder {
+    fn decode(
         &self,
-        _rows_to_skip: u32,
+        rows_to_skip: u32,
         num_rows: u32,
-        buffers: &mut [(u64, bool)],
         _all_null: &mut bool,
-    ) {
-        buffers[0].0 = self.bytes_per_value * num_rows as u64;
-        buffers[0].1 = true;
-    }
+    ) -> Result<Vec<BytesMut>> {
+        let num_bytes = self.bytes_per_value * num_rows as u64;
 
-    fn decode_into(
-        &self,
-        rows_to_skip: u32,
-        num_rows: u32,
-        dest_buffers: &mut [bytes::BytesMut],
-    ) -> Result<()> {
         let mut bytes_to_skip = rows_to_skip as u64 * self.bytes_per_value;
         let mut bytes_to_take = num_rows as u64 * self.bytes_per_value;
 
+        let mut dest_buffers = vec![BytesMut::with_capacity(num_bytes as usize)];
+
         let dest = &mut dest_buffers[0];
         debug_assert!(dest.capacity()
as u64 >= bytes_to_take); @@ -238,7 +231,7 @@ impl PhysicalPageDecoder for ValuePageDecoder { self.decode_buffer(buf, &mut bytes_to_skip, &mut bytes_to_take, dest); } } - Ok(()) + Ok(dest_buffers) } fn num_buffers(&self) -> u32 { From c4def70444977ca2a235f6524069b092afba5071 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Mon, 24 Jun 2024 14:16:20 -0700 Subject: [PATCH 10/13] feat: expose DatasetPreFilter, PreFilter, and FilterLoader to the public API (#2522) A [recent change](https://github.com/lancedb/lance/commit/36c08d586d51a77b0ef5b67adac86f709ff9241b) moved these types out of the public API of `lance_index` and into the private API of `lance`. This broke a customer (full disclosure, us) that was relying on these types. These types are well documented and straightforward. I think they were intended to be part of a public API. However, I could be wrong. Exposing them allows users to supply their own prefilter implementations when using `lance_index` utilities directly. I think it would be fair to call the API experimental if we want to but the prefilter has been stable for a bit. --- rust/lance/src/index.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 88f635377a..e4886ad312 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -36,7 +36,7 @@ use vector::ivf::v2::IVFIndex; pub(crate) mod append; pub(crate) mod cache; -pub(crate) mod prefilter; +pub mod prefilter; pub mod scalar; pub mod vector; From 816a58a262bc488d3a741c501051f98560b347a7 Mon Sep 17 00:00:00 2001 From: Lance Release Date: Mon, 24 Jun 2024 23:01:21 +0000 Subject: [PATCH 11/13] Bump version --- Cargo.toml | 30 +++++++++++++++--------------- python/Cargo.toml | 2 +- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 79efd71405..b739f6e94c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ exclude = ["python"] resolver = "2" [workspace.package] -version = "0.13.0" +version = "0.13.1" edition = "2021" authors = ["Lance Devs "] license = "Apache-2.0" @@ -43,20 +43,20 @@ categories = [ rust-version = "1.78" [workspace.dependencies] -lance = { version = "=0.13.0", path = "./rust/lance" } -lance-arrow = { version = "=0.13.0", path = "./rust/lance-arrow" } -lance-core = { version = "=0.13.0", path = "./rust/lance-core" } -lance-datafusion = { version = "=0.13.0", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=0.13.0", path = "./rust/lance-datagen" } -lance-encoding = { version = "=0.13.0", path = "./rust/lance-encoding" } -lance-encoding-datafusion = { version = "=0.13.0", path = "./rust/lance-encoding-datafusion" } -lance-file = { version = "=0.13.0", path = "./rust/lance-file" } -lance-index = { version = "=0.13.0", path = "./rust/lance-index" } -lance-io = { version = "=0.13.0", path = "./rust/lance-io" } -lance-linalg = { version = "=0.13.0", path = "./rust/lance-linalg" } -lance-table = { version = "=0.13.0", path = "./rust/lance-table" } -lance-test-macros = { version = "=0.13.0", path = "./rust/lance-test-macros" } -lance-testing = { version = "=0.13.0", path = "./rust/lance-testing" } +lance = { version = "=0.13.1", path = "./rust/lance" } +lance-arrow = { version = "=0.13.1", path = "./rust/lance-arrow" } +lance-core = { version = "=0.13.1", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.13.1", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.13.1", path = "./rust/lance-datagen" } +lance-encoding = { version = 
"=0.13.1", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.13.1", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.13.1", path = "./rust/lance-file" } +lance-index = { version = "=0.13.1", path = "./rust/lance-index" } +lance-io = { version = "=0.13.1", path = "./rust/lance-io" } +lance-linalg = { version = "=0.13.1", path = "./rust/lance-linalg" } +lance-table = { version = "=0.13.1", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.13.1", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.13.1", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "51.0.0", optional = false, features = ["prettyprint"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index b35c7c4ea8..fddf1627fe 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.13.0" +version = "0.13.1" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" From 4a114b2e8d39496ae8fa34a8316faaef90c55e2d Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Mon, 24 Jun 2024 16:53:43 -0700 Subject: [PATCH 12/13] feat: enhance binary array encoding, make it the default (#2521) This adds support for the following things to the binary encoding: * Nulls * Large offsets * Different types (e.g. String vs Binary) In addition, I have changed the row limit on pages from `u32` to `u64`. Partly, this is because string arrays larger that 2GB are not unheard of and it may be nice in some cases to store them in one page. Also pages are infrequent enough that an extra 4 bytes is trivial (actually, protobuf already stores u32 and u64 the same). Another reason is that, with some encodings (e.g. constant / compressed bitmap / etc.) it's possible to have 2^64 rows without occupying a large amount of space. Note: this is technically a `.proto` change but in protobuf `u32 -> u64` is backwards compatible so it is not a huge issue. 
--- protos/encodings.proto | 1 + protos/file2.proto | 7 +- protos/table.proto | 4 +- rust/lance-encoding/src/decoder.rs | 96 +-- rust/lance-encoding/src/encoder.rs | 65 +- rust/lance-encoding/src/encodings/logical.rs | 1 - .../src/encodings/logical/binary.rs | 296 -------- .../src/encodings/logical/list.rs | 48 +- .../src/encodings/logical/primitive.rs | 150 ++-- .../src/encodings/logical/struct.rs | 101 +-- rust/lance-encoding/src/encodings/physical.rs | 43 +- .../src/encodings/physical/basic.rs | 6 +- .../src/encodings/physical/binary.rs | 699 ++++++++++++------ .../src/encodings/physical/bitmap.rs | 14 +- .../src/encodings/physical/fixed_size_list.rs | 12 +- .../src/encodings/physical/value.rs | 22 +- rust/lance-file/src/v2/reader.rs | 221 +----- 17 files changed, 729 insertions(+), 1057 deletions(-) delete mode 100644 rust/lance-encoding/src/encodings/logical/binary.rs diff --git a/protos/encodings.proto b/protos/encodings.proto index 36ce5a7234..ffd0216afe 100644 --- a/protos/encodings.proto +++ b/protos/encodings.proto @@ -183,6 +183,7 @@ message SimpleStruct {} message Binary { ArrayEncoding indices = 1; ArrayEncoding bytes = 2; + uint64 null_adjustment = 3; } // Encodings that decode into an Arrow array diff --git a/protos/file2.proto b/protos/file2.proto index 63fc379091..31c9b50d49 100644 --- a/protos/file2.proto +++ b/protos/file2.proto @@ -15,10 +15,9 @@ import "google/protobuf/empty.proto"; // // * Each Lance file contains between 0 and 4Gi columns // * Each column contains between 0 and 4Gi pages -// * Each page contains between 0 and 4Gi items +// * Each page contains between 0 and 2^64 items // * Different pages within a column can have different items counts -// * Columns may have more than 4Gi items, though this will require more than -// one page +// * Columns may have up to 2^64 items // * Different columns within a file can have different item counts // // The Lance file format does not have any notion of a type system or schemas. @@ -178,7 +177,7 @@ message ColumnMetadata { // may be empty. repeated uint64 buffer_sizes = 2; // Logical length (e.g. 
# rows) of the page - uint32 length = 3; + uint64 length = 3; // The encoding used to encode the page Encoding encoding = 4; } diff --git a/protos/table.proto b/protos/table.proto index 8daca05e1f..ff500cdddb 100644 --- a/protos/table.proto +++ b/protos/table.proto @@ -235,7 +235,7 @@ message DataFile { // - dimension: packed-struct (0): // - x: u32 (1) // - y: u32 (2) - // - path: string (3) + // - path: list (3) // - embedding: fsl<768> (4) // - fp64 // - borders: fsl<4> (5) @@ -249,7 +249,7 @@ message DataFile { // This reflects quite a few phenomenon: // - The packed struct is encoded into a single column and there is no top-level column // for the x or y fields - // - The string is encoded into two columns + // - The variable sized list is encoded into two columns // - The embedding is encoded into a single column (common for FSL of primitive) and there // is not "FSL column" // - The borders field actually does have an "FSL column" diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs index 2f5d2c06cc..0ecf127c0a 100644 --- a/rust/lance-encoding/src/decoder.rs +++ b/rust/lance-encoding/src/decoder.rs @@ -222,6 +222,7 @@ use bytes::{Bytes, BytesMut}; use futures::future::BoxFuture; use futures::stream::{BoxStream, FuturesOrdered}; use futures::{FutureExt, StreamExt, TryStreamExt}; +use lance_arrow::DataTypeExt; use lance_core::datatypes::{Field, Schema}; use log::trace; use snafu::{location, Location}; @@ -230,9 +231,7 @@ use tokio::sync::mpsc::{self, unbounded_channel}; use lance_core::{Error, Result}; use tracing::instrument; -use crate::encoder::get_str_encoding_type; use crate::encoder::{values_column_encoding, EncodedBatch}; -use crate::encodings::logical::binary::BinaryFieldScheduler; use crate::encodings::logical::list::{ListFieldScheduler, OffsetPageInfo}; use crate::encodings::logical::primitive::PrimitiveFieldScheduler; use crate::encodings::logical::r#struct::{SimpleStructDecoder, SimpleStructScheduler}; @@ -246,7 +245,7 @@ use crate::{BufferScheduler, EncodingsIo}; #[derive(Debug)] pub struct PageInfo { /// The number of rows in the page - pub num_rows: u32, + pub num_rows: u64, /// The encoding that explains the buffers in the page pub encoding: pb::ArrayEncoding, /// The offsets and sizes of the buffers in the file @@ -550,26 +549,12 @@ impl CoreFieldDecoderStrategy { } fn is_primitive(data_type: &DataType) -> bool { - if data_type.is_primitive() { + if data_type.is_primitive() | data_type.is_binary_like() { true - } else if get_str_encoding_type() { - match data_type { - // DataType::is_primitive doesn't consider these primitive but we do - DataType::Boolean - | DataType::Null - | DataType::FixedSizeBinary(_) - | DataType::Utf8 => true, - DataType::FixedSizeList(inner, _) => Self::is_primitive(inner.data_type()), - _ => false, - } } else { match data_type { // DataType::is_primitive doesn't consider these primitive but we do - DataType::Boolean - | DataType::Null - | DataType::FixedSizeBinary(_) - // | DataType::Utf8 - => true, + DataType::Boolean | DataType::Null | DataType::FixedSizeBinary(_) => true, DataType::FixedSizeList(inner, _) => Self::is_primitive(inner.data_type()), _ => false, } @@ -711,28 +696,6 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy { .boxed(); Ok((chain, list_scheduler_fut)) } - DataType::Utf8 | DataType::Binary | DataType::LargeBinary | DataType::LargeUtf8 => { - let list_type = if matches!(data_type, DataType::Utf8 | DataType::Binary) { - DataType::List(Arc::new(ArrowField::new("item", DataType::UInt8, 
true))) - } else { - DataType::LargeList(Arc::new(ArrowField::new("item", DataType::UInt8, true))) - }; - let list_field = ArrowField::new(&field.name, list_type, true); - let list_field = Field::try_from(&list_field).unwrap(); - // We've changed the data type but are still decoding the same "field" - let (chain, list_decoder) = - chain.restart_at_current(&list_field, column_infos, buffers)?; - let data_type = data_type.clone(); - let binary_scheduler_fut = async move { - let list_decoder = list_decoder.await?; - Ok( - Arc::new(BinaryFieldScheduler::new(list_decoder, data_type.clone())) - as Arc, - ) - } - .boxed(); - Ok((chain, binary_scheduler_fut)) - } DataType::Struct(fields) => { let column_info = column_infos.pop_front().unwrap(); Self::check_simple_struct(&column_info, chain.current_path()).unwrap(); @@ -775,9 +738,9 @@ fn root_column(num_rows: u64) -> ColumnInfo { let root_pages = (0..num_root_pages) .map(|i| PageInfo { num_rows: if i == num_root_pages - 1 { - final_page_num_rows as u32 + final_page_num_rows } else { - u32::MAX + u64::MAX }, encoding: pb::ArrayEncoding { array_encoding: Some(pb::array_encoding::ArrayEncoding::Struct( @@ -861,8 +824,8 @@ impl DecodeBatchScheduler { return; } let next_scan_line = maybe_next_scan_line.unwrap(); - num_rows_scheduled += next_scan_line.rows_scheduled as u64; - rows_to_schedule -= next_scan_line.rows_scheduled as u64; + num_rows_scheduled += next_scan_line.rows_scheduled; + rows_to_schedule -= next_scan_line.rows_scheduled; trace!( "Scheduled scan line of {} rows and {} decoders", next_scan_line.rows_scheduled, @@ -1060,11 +1023,10 @@ impl BatchDecodeStream { return Ok(None); } - let mut to_take = self.rows_remaining.min(self.rows_per_batch as u64) as u32; - self.rows_remaining -= to_take as u64; + let mut to_take = self.rows_remaining.min(self.rows_per_batch as u64); + self.rows_remaining -= to_take; - let scheduled_need = - (self.rows_drained + to_take as u64).saturating_sub(self.rows_scheduled); + let scheduled_need = (self.rows_drained + to_take).saturating_sub(self.rows_scheduled); trace!("scheduled_need = {} because rows_drained = {} and to_take = {} and rows_scheduled = {}", scheduled_need, self.rows_drained, to_take, self.rows_scheduled); if scheduled_need > 0 { let desired_scheduled = scheduled_need + self.rows_scheduled; @@ -1075,7 +1037,7 @@ impl BatchDecodeStream { let actually_scheduled = self.wait_for_scheduled(desired_scheduled).await?; if actually_scheduled < desired_scheduled { let under_scheduled = desired_scheduled - actually_scheduled; - to_take -= under_scheduled as u32; + to_take -= under_scheduled; } } @@ -1083,17 +1045,17 @@ impl BatchDecodeStream { return Ok(None); } - let avail = self.root_decoder.avail_u64(); + let avail = self.root_decoder.avail(); trace!("Top level page has {} rows already available", avail); - if avail < to_take as u64 { + if avail < to_take { trace!( "Top level page waiting for an additional {} rows", - to_take as u64 - avail + to_take - avail ); self.root_decoder.wait(to_take).await?; } let next_task = self.root_decoder.drain(to_take)?; - self.rows_drained += to_take as u64; + self.rows_drained += to_take; Ok(Some(next_task)) } @@ -1125,7 +1087,12 @@ impl BatchDecodeStream { }); next_task.map(|(task, num_rows)| { let task = task.map(|join_wrapper| join_wrapper.unwrap()).boxed(); - let next_task = ReadBatchTask { task, num_rows }; + // This should be true since batch size is u32 + debug_assert!(num_rows <= u32::MAX as u64); + let next_task = ReadBatchTask { + task, + num_rows: num_rows as 
u32, + }; (next_task, slf) }) }); @@ -1178,8 +1145,8 @@ pub trait PrimitivePageDecoder: Send + Sync { /// * `all_null` - A mutable bool, set to true if a decoder determines all values are null fn decode( &self, - rows_to_skip: u32, - num_rows: u32, + rows_to_skip: u64, + num_rows: u64, all_null: &mut bool, ) -> Result>; fn num_buffers(&self) -> u32; @@ -1193,7 +1160,6 @@ pub trait PrimitivePageDecoder: Send + Sync { /// be shared in follow-up I/O tasks. /// /// See [`crate::decoder`] for more information - pub trait PageScheduler: Send + Sync + std::fmt::Debug { /// Schedules a batch of I/O to load the data needed for the requested ranges /// @@ -1208,7 +1174,7 @@ pub trait PageScheduler: Send + Sync + std::fmt::Debug { /// scheduled. This can be used to assign priority to I/O requests fn schedule_ranges( &self, - ranges: &[Range], + ranges: &[Range], scheduler: &Arc, top_level_row: u64, ) -> BoxFuture<'static, Result>>; @@ -1284,7 +1250,7 @@ impl SchedulerContext { #[derive(Debug)] pub struct ScheduledScanLine { - pub rows_scheduled: u32, + pub rows_scheduled: u64, pub decoders: Vec, } @@ -1366,7 +1332,7 @@ pub struct NextDecodeTask { /// The decode task itself pub task: Box, /// The number of rows that will be created - pub num_rows: u32, + pub num_rows: u64, /// Whether or not the decoder that created this still has more rows to decode pub has_more: bool, } @@ -1434,13 +1400,13 @@ pub trait LogicalPageDecoder: std::fmt::Debug + Send { }) } /// Waits for enough data to be loaded to decode `num_rows` of data - fn wait(&mut self, num_rows: u32) -> BoxFuture>; + fn wait(&mut self, num_rows: u64) -> BoxFuture>; /// Creates a task to decode `num_rows` of data into an array - fn drain(&mut self, num_rows: u32) -> Result; + fn drain(&mut self, num_rows: u64) -> Result; /// The number of rows that are in the page but haven't yet been "waited" - fn unawaited(&self) -> u32; + fn unawaited(&self) -> u64; /// The number of rows that have been "waited" but not yet decoded - fn avail(&self) -> u32; + fn avail(&self) -> u64; /// The data type of the decoded data fn data_type(&self) -> &DataType; } diff --git a/rust/lance-encoding/src/encoder.rs b/rust/lance-encoding/src/encoder.rs index 33c82cede3..ca368f7a9a 100644 --- a/rust/lance-encoding/src/encoder.rs +++ b/rust/lance-encoding/src/encoder.rs @@ -15,8 +15,7 @@ use crate::{ decoder::{ColumnInfo, PageInfo}, encodings::{ logical::{ - binary::BinaryFieldEncoder, list::ListFieldEncoder, primitive::PrimitiveFieldEncoder, - r#struct::StructFieldEncoder, + list::ListFieldEncoder, primitive::PrimitiveFieldEncoder, r#struct::StructFieldEncoder, }, physical::{ basic::BasicEncoder, binary::BinaryEncoder, fixed_size_list::FslEncoder, @@ -102,7 +101,7 @@ pub struct EncodedPage { // The encoded array data pub array: EncodedArray, /// The number of rows in the encoded page - pub num_rows: u32, + pub num_rows: u64, /// The index of the column pub column_idx: u32, } @@ -228,11 +227,6 @@ fn get_compression_scheme() -> CompressionScheme { parse_compression_scheme(&compression_scheme).unwrap_or(CompressionScheme::None) } -pub fn get_str_encoding_type() -> bool { - let str_encoding = std::env::var("LANCE_STR_ARRAY_ENCODING").unwrap_or("none".to_string()); - matches!(str_encoding.as_str(), "binary") -} - impl CoreArrayEncodingStrategy { fn array_encoder_from_type(data_type: &DataType) -> Result> { match data_type { @@ -242,20 +236,14 @@ impl CoreArrayEncodingStrategy { *dimension as u32, ))))) } - DataType::Utf8 => { - if get_str_encoding_type() { - let 
bin_indices_encoder = Self::array_encoder_from_type(&DataType::UInt64)?; - let bin_bytes_encoder = Self::array_encoder_from_type(&DataType::UInt8)?; - - Ok(Box::new(BinaryEncoder::new( - bin_indices_encoder, - bin_bytes_encoder, - ))) - } else { - Ok(Box::new(BasicEncoder::new(Box::new( - ValueEncoder::try_new(data_type, get_compression_scheme())?, - )))) - } + DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => { + let bin_indices_encoder = Self::array_encoder_from_type(&DataType::UInt64)?; + let bin_bytes_encoder = Self::array_encoder_from_type(&DataType::UInt8)?; + + Ok(Box::new(BinaryEncoder::new( + bin_indices_encoder, + bin_bytes_encoder, + ))) } _ => Ok(Box::new(BasicEncoder::new(Box::new( ValueEncoder::try_new(data_type, get_compression_scheme())?, @@ -369,30 +357,16 @@ impl FieldEncodingStrategy for CoreFieldEncodingStrategy { | DataType::UInt64 | DataType::UInt8 | DataType::FixedSizeBinary(_) - | DataType::FixedSizeList(_, _) => Ok(Box::new(PrimitiveFieldEncoder::try_new( + | DataType::FixedSizeList(_, _) + | DataType::Binary + | DataType::LargeBinary + | DataType::Utf8 + | DataType::LargeUtf8 => Ok(Box::new(PrimitiveFieldEncoder::try_new( cache_bytes_per_column, keep_original_array, self.array_encoding_strategy.clone(), column_index.next_column_index(field.id), )?)), - DataType::Utf8 => { - if get_str_encoding_type() { - Ok(Box::new(PrimitiveFieldEncoder::try_new( - cache_bytes_per_column, - keep_original_array, - self.array_encoding_strategy.clone(), - column_index.next_column_index(field.id), - )?)) - } else { - let list_idx = column_index.next_column_index(field.id); - column_index.skip(); - Ok(Box::new(BinaryFieldEncoder::new( - cache_bytes_per_column, - keep_original_array, - list_idx, - ))) - } - } DataType::List(child) => { let list_idx = column_index.next_column_index(field.id); let inner_encoding = encoding_strategy_root.create_field_encoder( @@ -431,15 +405,6 @@ impl FieldEncodingStrategy for CoreFieldEncodingStrategy { header_idx, ))) } - DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => { - let list_idx = column_index.next_column_index(field.id); - column_index.skip(); - Ok(Box::new(BinaryFieldEncoder::new( - cache_bytes_per_column, - keep_original_array, - list_idx, - ))) - } _ => todo!("Implement encoding for field {}", field), } } diff --git a/rust/lance-encoding/src/encodings/logical.rs b/rust/lance-encoding/src/encodings/logical.rs index adca795788..cad0535571 100644 --- a/rust/lance-encoding/src/encodings/logical.rs +++ b/rust/lance-encoding/src/encodings/logical.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -pub mod binary; pub mod list; pub mod primitive; pub mod r#struct; diff --git a/rust/lance-encoding/src/encodings/logical/binary.rs b/rust/lance-encoding/src/encodings/logical/binary.rs deleted file mode 100644 index 79c3847326..0000000000 --- a/rust/lance-encoding/src/encodings/logical/binary.rs +++ /dev/null @@ -1,296 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -use std::sync::Arc; - -use arrow_array::{ - cast::AsArray, - types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, UInt8Type, Utf8Type}, - Array, ArrayRef, GenericByteArray, GenericListArray, LargeListArray, ListArray, UInt8Array, -}; - -use arrow_buffer::ScalarBuffer; -use arrow_schema::{DataType, Field}; -use futures::future::BoxFuture; -use lance_core::Result; -use log::trace; - -use crate::{ - decoder::{ - 
DecodeArrayTask, DecoderReady, FieldScheduler, FilterExpression, LogicalPageDecoder, - NextDecodeTask, ScheduledScanLine, SchedulerContext, SchedulingJob, - }, - encoder::{CoreArrayEncodingStrategy, EncodeTask, FieldEncoder}, -}; - -use super::{list::ListFieldEncoder, primitive::PrimitiveFieldEncoder}; - -/// Wraps a varbin scheduler and uses a BinaryPageDecoder to cast -/// the result to the appropriate type -#[derive(Debug)] -pub struct BinarySchedulingJob<'a> { - scheduler: &'a BinaryFieldScheduler, - inner: Box, -} - -impl<'a> SchedulingJob for BinarySchedulingJob<'a> { - fn schedule_next( - &mut self, - context: &mut SchedulerContext, - top_level_row: u64, - ) -> Result { - let inner_scan = self.inner.schedule_next(context, top_level_row)?; - let wrapped_decoders = inner_scan - .decoders - .into_iter() - .map(|decoder| DecoderReady { - path: decoder.path, - decoder: Box::new(BinaryPageDecoder { - inner: decoder.decoder, - data_type: self.scheduler.data_type.clone(), - }), - }) - .collect::>(); - Ok(ScheduledScanLine { - decoders: wrapped_decoders, - rows_scheduled: inner_scan.rows_scheduled, - }) - } - - fn num_rows(&self) -> u64 { - self.inner.num_rows() - } -} - -/// A logical scheduler for utf8/binary pages which assumes the data are encoded as List -#[derive(Debug)] -pub struct BinaryFieldScheduler { - varbin_scheduler: Arc, - data_type: DataType, -} - -impl BinaryFieldScheduler { - // Create a new ListPageScheduler - pub fn new(varbin_scheduler: Arc, data_type: DataType) -> Self { - Self { - varbin_scheduler, - data_type, - } - } -} - -impl FieldScheduler for BinaryFieldScheduler { - fn schedule_ranges<'a>( - &'a self, - ranges: &[std::ops::Range], - filter: &FilterExpression, - ) -> Result> { - trace!("Scheduling binary for {} ranges", ranges.len()); - let varbin_job = self.varbin_scheduler.schedule_ranges(ranges, filter)?; - Ok(Box::new(BinarySchedulingJob { - scheduler: self, - inner: varbin_job, - })) - } - - fn num_rows(&self) -> u64 { - self.varbin_scheduler.num_rows() - } -} - -#[derive(Debug)] -pub struct BinaryPageDecoder { - inner: Box, - data_type: DataType, -} - -impl LogicalPageDecoder for BinaryPageDecoder { - fn wait(&mut self, num_rows: u32) -> BoxFuture> { - self.inner.wait(num_rows) - } - - fn drain(&mut self, num_rows: u32) -> Result { - let inner_task = self.inner.drain(num_rows)?; - Ok(NextDecodeTask { - has_more: inner_task.has_more, - num_rows: inner_task.num_rows, - task: Box::new(BinaryArrayDecoder { - inner: inner_task.task, - data_type: self.data_type.clone(), - }), - }) - } - - fn unawaited(&self) -> u32 { - self.inner.unawaited() - } - - fn avail(&self) -> u32 { - self.inner.avail() - } - - fn data_type(&self) -> &DataType { - &self.data_type - } -} - -pub struct BinaryArrayDecoder { - inner: Box, - data_type: DataType, -} - -impl BinaryArrayDecoder { - fn from_list_array(array: &GenericListArray) -> ArrayRef { - let values = array - .values() - .as_primitive::() - .values() - .inner() - .clone(); - let offsets = array.offsets().clone(); - Arc::new(GenericByteArray::::new( - offsets, - values, - array.nulls().cloned(), - )) - } -} - -impl DecodeArrayTask for BinaryArrayDecoder { - fn decode(self: Box) -> Result { - let data_type = self.data_type; - let arr = self.inner.decode()?; - match data_type { - DataType::Binary => Ok(Self::from_list_array::(arr.as_list::())), - DataType::LargeBinary => Ok(Self::from_list_array::( - arr.as_list::(), - )), - DataType::Utf8 => Ok(Self::from_list_array::(arr.as_list::())), - DataType::LargeUtf8 => 
Ok(Self::from_list_array::(arr.as_list::())), - _ => panic!("Binary decoder does not support this data type"), - } - } -} - -/// An encoder which encodes string arrays as List -pub struct BinaryFieldEncoder { - varbin_encoder: Box, -} - -impl BinaryFieldEncoder { - pub fn new(cache_bytes_per_column: u64, keep_original_array: bool, column_index: u32) -> Self { - let items_encoder = Box::new( - PrimitiveFieldEncoder::try_new( - cache_bytes_per_column, - keep_original_array, - Arc::new(CoreArrayEncodingStrategy), - column_index + 1, - ) - .unwrap(), - ); - Self { - varbin_encoder: Box::new(ListFieldEncoder::new( - items_encoder, - cache_bytes_per_column, - keep_original_array, - column_index, - )), - } - } - - fn byte_to_list_array>( - array: &GenericByteArray, - ) -> ListArray { - let values = UInt8Array::new( - ScalarBuffer::::new(array.values().clone(), 0, array.values().len()), - None, - ); - let list_field = Field::new("item", DataType::UInt8, true); - ListArray::new( - Arc::new(list_field), - array.offsets().clone(), - Arc::new(values), - array.nulls().cloned(), - ) - } - - fn byte_to_large_list_array>( - array: &GenericByteArray, - ) -> LargeListArray { - let values = UInt8Array::new( - ScalarBuffer::::new(array.values().clone(), 0, array.values().len()), - None, - ); - let list_field = Field::new("item", DataType::UInt8, true); - LargeListArray::new( - Arc::new(list_field), - array.offsets().clone(), - Arc::new(values), - array.nulls().cloned(), - ) - } - - fn to_list_array(array: ArrayRef) -> ArrayRef { - match array.data_type() { - DataType::Utf8 => Arc::new(Self::byte_to_list_array(array.as_string::())), - DataType::LargeUtf8 => { - Arc::new(Self::byte_to_large_list_array(array.as_string::())) - } - DataType::Binary => Arc::new(Self::byte_to_list_array(array.as_binary::())), - DataType::LargeBinary => { - Arc::new(Self::byte_to_large_list_array(array.as_binary::())) - } - _ => panic!("Binary encoder does not support {}", array.data_type()), - } - } -} - -impl FieldEncoder for BinaryFieldEncoder { - fn maybe_encode(&mut self, array: ArrayRef) -> Result> { - let list_array = Self::to_list_array(array); - self.varbin_encoder.maybe_encode(Arc::new(list_array)) - } - - fn flush(&mut self) -> Result> { - self.varbin_encoder.flush() - } - - fn num_columns(&self) -> u32 { - 2 - } - - fn finish(&mut self) -> BoxFuture<'_, Result>> { - self.varbin_encoder.finish() - } -} - -#[cfg(test)] -mod tests { - use arrow_schema::{DataType, Field}; - - use crate::testing::check_round_trip_encoding_random; - - #[test_log::test(tokio::test)] - async fn test_utf8() { - let field = Field::new("", DataType::Utf8, false); - check_round_trip_encoding_random(field).await; - } - - #[test_log::test(tokio::test)] - async fn test_binary() { - let field = Field::new("", DataType::Binary, false); - check_round_trip_encoding_random(field).await; - } - - #[test_log::test(tokio::test)] - async fn test_large_binary() { - let field = Field::new("", DataType::LargeBinary, true); - check_round_trip_encoding_random(field).await; - } - - #[test_log::test(tokio::test)] - async fn test_large_utf8() { - let field = Field::new("", DataType::LargeUtf8, true); - check_round_trip_encoding_random(field).await; - } -} diff --git a/rust/lance-encoding/src/encodings/logical/list.rs b/rust/lance-encoding/src/encodings/logical/list.rs index 48bfcfdd8b..fd7d290ef8 100644 --- a/rust/lance-encoding/src/encodings/logical/list.rs +++ b/rust/lance-encoding/src/encodings/logical/list.rs @@ -97,9 +97,9 @@ impl ListRequestsIter { let mut range = 
range.clone(); // Skip any offsets pages that are before the range - while offsets_offset + (cur_page_info.offsets_in_page as u64) <= range.start { + while offsets_offset + (cur_page_info.offsets_in_page) <= range.start { trace!("Skipping null offset adjustment chunk {:?}", offsets_offset); - offsets_offset += cur_page_info.offsets_in_page as u64; + offsets_offset += cur_page_info.offsets_in_page; items_offset += cur_page_info.num_items_referenced_by_page; cur_page_info = page_infos_iter.next().unwrap(); } @@ -118,7 +118,7 @@ impl ListRequestsIter { while !range.is_empty() { // The end of the list request is the min of the end of the range // and the end of the current page - let end = offsets_offset + cur_page_info.offsets_in_page as u64; + let end = offsets_offset + cur_page_info.offsets_in_page; let last = end >= range.end; let end = end.min(range.end); list_requests.push_back(ListRequest { @@ -133,7 +133,7 @@ impl ListRequestsIter { // If there is still more data in the range, we need to move to the // next page if !last { - offsets_offset += cur_page_info.offsets_in_page as u64; + offsets_offset += cur_page_info.offsets_in_page; items_offset += cur_page_info.num_items_referenced_by_page; cur_page_info = page_infos_iter.next().unwrap(); } @@ -146,7 +146,7 @@ impl ListRequestsIter { } // Given a page of offset data, grab the corresponding list requests - fn next(&mut self, mut num_offsets: u32) -> Vec { + fn next(&mut self, mut num_offsets: u64) -> Vec { let mut list_requests = Vec::new(); while num_offsets > 0 { let req = self.list_requests.front_mut().unwrap(); @@ -155,12 +155,12 @@ impl ListRequestsIter { num_offsets -= 1; debug_assert_ne!(num_offsets, 0); } - if num_offsets as u64 >= req.num_lists { - num_offsets -= req.num_lists as u32; + if num_offsets >= req.num_lists { + num_offsets -= req.num_lists; list_requests.push(self.list_requests.pop_front().unwrap()); } else { let sub_req = ListRequest { - num_lists: num_offsets as u64, + num_lists: num_offsets, includes_extra_offset: req.includes_extra_offset, null_offset_adjustment: req.null_offset_adjustment, items_offset: req.items_offset, @@ -168,7 +168,7 @@ impl ListRequestsIter { list_requests.push(sub_req); req.includes_extra_offset = false; - req.num_lists -= num_offsets as u64; + req.num_lists -= num_offsets; num_offsets = 0; } } @@ -430,7 +430,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> { debug_assert!(list_reqs .iter() .all(|req| req.null_offset_adjustment == null_offset_adjustment)); - let num_rows = list_reqs.iter().map(|req| req.num_lists).sum::() as u32; + let num_rows = list_reqs.iter().map(|req| req.num_lists).sum::(); // offsets is a uint64 which is guaranteed to create one decoder on each call to schedule_next let next_offsets_decoder = next_offsets.decoders.into_iter().next().unwrap().decoder; @@ -502,7 +502,7 @@ pub struct ListFieldScheduler { /// This is needed to construct the scheduler #[derive(Debug)] pub struct OffsetPageInfo { - pub offsets_in_page: u32, + pub offsets_in_page: u64, pub null_offset_adjustment: u64, pub num_items_referenced_by_page: u64, } @@ -570,9 +570,9 @@ struct ListPageDecoder { offsets: Vec, validity: BooleanBuffer, item_decoder: Option, - lists_available: u32, - num_rows: u32, - rows_drained: u32, + lists_available: u64, + num_rows: u64, + rows_drained: u64, items_type: DataType, offset_type: DataType, data_type: DataType, @@ -642,7 +642,7 @@ impl DecodeArrayTask for ListDecodeTask { } impl LogicalPageDecoder for ListPageDecoder { - fn wait(&mut self, num_rows: u32) -> 
BoxFuture> { + fn wait(&mut self, num_rows: u64) -> BoxFuture> { async move { // wait for the indirect I/O to finish, run the scheduler for the indirect // I/O and then wait for enough items to arrive @@ -679,7 +679,7 @@ impl LogicalPageDecoder for ListPageDecoder { self.offsets[offset_wait_start as usize + num_rows as usize] - item_start; if items_needed > 0 { // First discount any already available items - let items_already_available = self.item_decoder.as_mut().unwrap().avail_u64(); + let items_already_available = self.item_decoder.as_mut().unwrap().avail(); trace!( "List's items decoder needs {} items and already has {} items available", items_needed, @@ -687,7 +687,7 @@ impl LogicalPageDecoder for ListPageDecoder { ); items_needed = items_needed.saturating_sub(items_already_available); if items_needed > 0 { - self.item_decoder.as_mut().unwrap().wait_u64(items_needed).await?; + self.item_decoder.as_mut().unwrap().wait(items_needed).await?; } } // This is technically undercounting a little. It's possible that we loaded a big items @@ -700,11 +700,11 @@ impl LogicalPageDecoder for ListPageDecoder { .boxed() } - fn unawaited(&self) -> u32 { + fn unawaited(&self) -> u64 { self.num_rows - self.lists_available - self.rows_drained } - fn drain(&mut self, num_rows: u32) -> Result { + fn drain(&mut self, num_rows: u64) -> Result { self.lists_available -= num_rows; // We already have the offsets but need to drain the item pages let mut actual_num_rows = num_rows; @@ -745,7 +745,7 @@ impl LogicalPageDecoder for ListPageDecoder { } else { self.item_decoder .as_mut() - .map(|item_decoder| Result::Ok(item_decoder.drain_u64(num_items_to_drain)?.task)) + .map(|item_decoder| Result::Ok(item_decoder.drain(num_items_to_drain)?.task)) .transpose()? }; @@ -763,7 +763,7 @@ impl LogicalPageDecoder for ListPageDecoder { }) } - fn avail(&self) -> u32 { + fn avail(&self) -> u64 { self.lists_available } @@ -878,7 +878,7 @@ impl ListOffsetsEncoder { tokio::task::spawn(async move { let num_rows = offset_arrays.iter().map(|arr| arr.len()).sum::() - offset_arrays.len(); - let num_rows = num_rows as u32; + let num_rows = num_rows as u64; let mut buffer_index = 0; let array = Self::do_encode( offset_arrays, @@ -1012,7 +1012,7 @@ impl ListOffsetsEncoder { fn do_encode_u64( offset_arrays: Vec, validity: Vec>, - num_offsets: u32, + num_offsets: u64, null_offset_adjustment: u64, buffer_index: &mut u32, inner_encoder: Arc, @@ -1035,7 +1035,7 @@ impl ListOffsetsEncoder { offset_arrays: Vec, validity_arrays: Vec, buffer_index: &mut u32, - num_offsets: u32, + num_offsets: u64, inner_encoder: Arc, ) -> Result { let validity_arrays = validity_arrays diff --git a/rust/lance-encoding/src/encodings/logical/primitive.rs b/rust/lance-encoding/src/encodings/logical/primitive.rs index 1f1457f6f2..44dab86ffd 100644 --- a/rust/lance-encoding/src/encodings/logical/primitive.rs +++ b/rust/lance-encoding/src/encodings/logical/primitive.rs @@ -6,15 +6,17 @@ use std::{fmt::Debug, ops::Range, sync::Arc}; use arrow_array::{ new_null_array, types::{ - ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + ArrowPrimitiveType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, 
Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + DurationSecondType, Float16Type, Float32Type, Float64Type, GenericBinaryType, + GenericStringType, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, }, - ArrayRef, BooleanArray, FixedSizeBinaryArray, FixedSizeListArray, PrimitiveArray, StringArray, + ArrayRef, BooleanArray, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, + PrimitiveArray, }; use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_schema::{DataType, IntervalUnit, TimeUnit}; @@ -39,7 +41,7 @@ use crate::{ #[derive(Debug)] struct PrimitivePage { scheduler: Box, - num_rows: u32, + num_rows: u64, } /// A field scheduler for primitive fields @@ -67,14 +69,15 @@ impl PrimitiveFieldScheduler { column_buffers: buffers, positions_and_sizes: &page.buffer_offsets_and_sizes, }; - let scheduler = decoder_from_array_encoding(&page.encoding, &page_buffers); + let scheduler = + decoder_from_array_encoding(&page.encoding, &page_buffers, &data_type); PrimitivePage { scheduler, num_rows: page.num_rows, } }) .collect::>(); - let num_rows = page_schedulers.iter().map(|p| p.num_rows as u64).sum(); + let num_rows = page_schedulers.iter().map(|p| p.num_rows).sum(); Self { data_type, page_schedulers, @@ -124,8 +127,8 @@ impl<'a> SchedulingJob for PrimitiveFieldSchedulingJob<'a> { cur_page.num_rows ); // Skip entire pages until we have some overlap with our next range - while cur_page.num_rows as u64 + self.global_row_offset <= range.start { - self.global_row_offset += cur_page.num_rows as u64; + while cur_page.num_rows + self.global_row_offset <= range.start { + self.global_row_offset += cur_page.num_rows; self.page_idx += 1; trace!("Skipping entire page of {} rows", cur_page.num_rows); cur_page = &self.scheduler.page_schedulers[self.page_idx]; @@ -135,14 +138,14 @@ impl<'a> SchedulingJob for PrimitiveFieldSchedulingJob<'a> { // until we find a range that exceeds the current page let mut ranges_in_page = Vec::new(); - while cur_page.num_rows as u64 + self.global_row_offset > range.start { + while cur_page.num_rows + self.global_row_offset > range.start { range.start = range.start.max(self.global_row_offset); let start_in_page = range.start - self.global_row_offset; let end_in_page = start_in_page + (range.end - range.start); - let end_in_page = end_in_page.min(cur_page.num_rows as u64) as u32; - let last_in_range = (end_in_page as u64 + self.global_row_offset) >= range.end; + let end_in_page = end_in_page.min(cur_page.num_rows); + let last_in_range = (end_in_page + self.global_row_offset) >= range.end; - ranges_in_page.push(start_in_page as u32..end_in_page); + ranges_in_page.push(start_in_page..end_in_page); if last_in_range { self.range_idx += 1; if self.range_idx == self.ranges.len() { @@ -162,7 +165,7 @@ impl<'a> SchedulingJob for PrimitiveFieldSchedulingJob<'a> { cur_page.num_rows ); - self.global_row_offset += cur_page.num_rows as u64; + self.global_row_offset += cur_page.num_rows; self.page_idx += 1; let physical_decoder = @@ -213,15 +216,15 @@ pub struct PrimitiveFieldDecoder { data_type: DataType, unloaded_physical_decoder: 
Option>>>, physical_decoder: Option>, - num_rows: u32, - rows_drained: u32, + num_rows: u64, + rows_drained: u64, } impl PrimitiveFieldDecoder { pub fn new_from_data( physical_decoder: Arc, data_type: DataType, - num_rows: u32, + num_rows: u64, ) -> Self { Self { data_type, @@ -244,8 +247,8 @@ impl Debug for PrimitiveFieldDecoder { } struct PrimitiveFieldDecodeTask { - rows_to_skip: u32, - rows_to_take: u32, + rows_to_skip: u64, + rows_to_take: u64, physical_decoder: Arc, data_type: DataType, } @@ -276,7 +279,7 @@ impl PrimitiveFieldDecodeTask { // into a primitive array is pretty fundamental. fn new_primitive_array( buffers: Vec, - num_rows: u32, + num_rows: u64, data_type: &DataType, ) -> ArrayRef { let mut buffer_iter = buffers.into_iter(); @@ -302,7 +305,50 @@ impl PrimitiveFieldDecodeTask { ) } - fn bytes_to_validity(bytes: BytesMut, num_rows: u32) -> Option { + fn new_generic_byte_array(buffers: Vec, num_rows: u64) -> ArrayRef { + // iterate over buffers to get offsets and then bytes + let mut buffer_iter = buffers.into_iter(); + + let null_buffer = buffer_iter.next().unwrap(); + let null_buffer = if null_buffer.is_empty() { + None + } else { + let null_buffer = null_buffer.freeze().into(); + Some(NullBuffer::new(BooleanBuffer::new( + Buffer::from_bytes(null_buffer), + 0, + num_rows as usize, + ))) + }; + + let indices_bytes = buffer_iter.next().unwrap().freeze(); + let indices_buffer = Buffer::from_bytes(indices_bytes.into()); + let indices_buffer = + ScalarBuffer::::new(indices_buffer, 0, num_rows as usize + 1); + + let offsets = OffsetBuffer::new(indices_buffer.clone()); + + // TODO - add NULL support + // Decoding the bytes creates 2 buffers, the first one is empty due to nulls. + buffer_iter.next().unwrap(); + + let bytes_buffer = buffer_iter.next().unwrap().freeze(); + let bytes_buffer = Buffer::from_bytes(bytes_buffer.into()); + let bytes_buffer_len = bytes_buffer.len(); + let bytes_buffer = ScalarBuffer::::new(bytes_buffer, 0, bytes_buffer_len); + + let bytes_array = Arc::new( + PrimitiveArray::::new(bytes_buffer, None).with_data_type(DataType::UInt8), + ); + + Arc::new(GenericByteArray::::new( + offsets, + bytes_array.values().into(), + null_buffer, + )) + } + + fn bytes_to_validity(bytes: BytesMut, num_rows: u64) -> Option { if bytes.is_empty() { None } else { @@ -318,7 +364,7 @@ impl PrimitiveFieldDecodeTask { fn primitive_array_from_buffers( data_type: &DataType, buffers: Vec, - num_rows: u32, + num_rows: u64, ) -> Result { match data_type { DataType::Boolean => { @@ -463,7 +509,7 @@ impl PrimitiveFieldDecodeTask { let items_array = Self::primitive_array_from_buffers( items.data_type(), remaining_buffers, - num_rows * (*dimension as u32), + num_rows * (*dimension as u64), )?; Ok(Arc::new(FixedSizeListArray::new( items.clone(), @@ -472,36 +518,18 @@ impl PrimitiveFieldDecodeTask { fsl_nulls, ))) } - DataType::Utf8 => { - // iterate over buffers to get offsets and then bytes - let mut buffer_iter = buffers.into_iter(); - let indices_bytes = buffer_iter.next().unwrap().freeze(); - let indices_buffer = Buffer::from_bytes(indices_bytes.into()); - let indices_buffer = - ScalarBuffer::::new(indices_buffer, 0, num_rows as usize + 1); - - let offsets = OffsetBuffer::new(indices_buffer.clone()); - - // TODO - add NULL support - // Decoding the bytes creates 2 buffers, the first one is empty due to nulls. 
- let _null_buffer = buffer_iter.next().unwrap(); - - let bytes_buffer = buffer_iter.next().unwrap().freeze(); - let bytes_buffer = Buffer::from_bytes(bytes_buffer.into()); - let bytes_buffer_len = bytes_buffer.len(); - let bytes_buffer = ScalarBuffer::::new(bytes_buffer, 0, bytes_buffer_len); - - let bytes_array = Arc::new( - PrimitiveArray::::new(bytes_buffer, None) - .with_data_type(DataType::UInt8), - ); - - Ok(Arc::new(StringArray::new( - offsets, - bytes_array.values().into(), - None, - ))) - } + DataType::Utf8 => Ok(Self::new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::LargeUtf8 => Ok(Self::new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::Binary => Ok(Self::new_generic_byte_array::>( + buffers, num_rows, + )), + DataType::LargeBinary => Ok(Self::new_generic_byte_array::>( + buffers, num_rows, + )), _ => Err(Error::io( format!( "The data type {} cannot be decoded from a primitive encoding", @@ -516,7 +544,7 @@ impl PrimitiveFieldDecodeTask { impl LogicalPageDecoder for PrimitiveFieldDecoder { // TODO: In the future, at some point, we may consider partially waiting for primitive pages by // breaking up large I/O into smaller I/O as a way to accelerate the "time-to-first-decode" - fn wait(&mut self, _: u32) -> BoxFuture> { + fn wait(&mut self, _: u64) -> BoxFuture> { async move { let physical_decoder = self.unloaded_physical_decoder.take().unwrap().await?; self.physical_decoder = Some(Arc::from(physical_decoder)); @@ -525,7 +553,7 @@ impl LogicalPageDecoder for PrimitiveFieldDecoder { .boxed() } - fn drain(&mut self, num_rows: u32) -> Result { + fn drain(&mut self, num_rows: u64) -> Result { let rows_to_skip = self.rows_drained; let rows_to_take = num_rows; @@ -545,7 +573,7 @@ impl LogicalPageDecoder for PrimitiveFieldDecoder { }) } - fn unawaited(&self) -> u32 { + fn unawaited(&self) -> u64 { if self.unloaded_physical_decoder.is_some() { self.num_rows } else { @@ -553,7 +581,7 @@ impl LogicalPageDecoder for PrimitiveFieldDecoder { } } - fn avail(&self) -> u32 { + fn avail(&self) -> u64 { if self.unloaded_physical_decoder.is_some() { 0 } else { @@ -664,7 +692,7 @@ impl PrimitiveFieldEncoder { let column_idx = self.column_index; Ok(tokio::task::spawn(async move { - let num_rows = arrays.iter().map(|arr| arr.len() as u32).sum(); + let num_rows = arrays.iter().map(|arr| arr.len() as u64).sum(); let mut buffer_index = 0; let array = encoder.encode(&arrays, &mut buffer_index)?; Ok(EncodedPage { diff --git a/rust/lance-encoding/src/encodings/logical/struct.rs b/rust/lance-encoding/src/encodings/logical/struct.rs index 8fb5731469..6dc2609886 100644 --- a/rust/lance-encoding/src/encodings/logical/struct.rs +++ b/rust/lance-encoding/src/encodings/logical/struct.rs @@ -129,14 +129,14 @@ impl<'a> SchedulingJob for SimpleStructSchedulerJob<'a> { child_scan.rows_scheduled, next_child.col_idx ); - next_child.rows_scheduled += child_scan.rows_scheduled as u64; - next_child.rows_remaining -= child_scan.rows_scheduled as u64; + next_child.rows_scheduled += child_scan.rows_scheduled; + next_child.rows_remaining -= child_scan.rows_scheduled; decoders.extend(child_scan.decoders); self.children.push(next_child); self.rows_scheduled = self.children.peek().unwrap().rows_scheduled; context = scoped.pop(); } - let struct_rows_scheduled = (self.rows_scheduled - old_rows_scheduled) as u32; + let struct_rows_scheduled = self.rows_scheduled - old_rows_scheduled; Ok(ScheduledScanLine { decoders, rows_scheduled: struct_rows_scheduled, @@ -162,8 +162,6 @@ impl<'a> SchedulingJob for 
SimpleStructSchedulerJob<'a> { pub struct SimpleStructScheduler { children: Vec>, child_fields: Fields, - // A single page cannot contain more than u32 rows. However, we also use SimpleStructScheduler - // at the top level and a single file *can* contain more than u32 rows. num_rows: u64, } @@ -230,7 +228,7 @@ struct ChildState { struct CompositeDecodeTask { // One per child tasks: Vec>, - num_rows: u32, + num_rows: u64, has_more: bool, } @@ -273,7 +271,7 @@ impl ChildState { let mut remaining = num_rows.saturating_sub(self.rows_available); for next_decoder in &mut self.scheduled { if next_decoder.unawaited() > 0 { - let rows_to_wait = remaining.min(next_decoder.unawaited() as u64) as u32; + let rows_to_wait = remaining.min(next_decoder.unawaited()); trace!( "Struct await an additional {} rows from the current page", rows_to_wait @@ -289,14 +287,14 @@ impl ChildState { next_decoder.wait(rows_to_wait).await?; let newly_avail = next_decoder.avail() - previously_avail; trace!("The await loaded {} rows", newly_avail); - self.rows_available += newly_avail as u64; + self.rows_available += newly_avail; // Need to use saturating_sub here because we might have asked for range // 0-1000 and this page we just loaded might cover 900-1100 and so newly_avail // is 200 but rows_unawaited is only 100 // // TODO: Unit tests may not be covering this branch right now - self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail as u64); - remaining -= rows_to_wait as u64; + self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail); + remaining -= rows_to_wait; if remaining == 0 { break; } @@ -322,13 +320,13 @@ impl ChildState { }; while remaining > 0 { let next = self.scheduled.front_mut().unwrap(); - let rows_to_take = remaining.min(next.avail() as u64) as u32; + let rows_to_take = remaining.min(next.avail()); let next_task = next.drain(rows_to_take)?; if next.avail() == 0 && next.unawaited() == 0 { trace!("Completely drained page"); self.scheduled.pop_front(); } - remaining -= rows_to_take as u64; + remaining -= rows_to_take; composite.tasks.push(next_task.task); composite.num_rows += next_task.num_rows; } @@ -357,53 +355,6 @@ impl SimpleStructDecoder { data_type, } } - - pub fn avail_u64(&self) -> u64 { - self.children - .iter() - .map(|c| c.rows_available) - .min() - .unwrap() - } - - // Rows are unawaited if they are unawaited in any child column - pub fn unawaited_u64(&self) -> u64 { - self.children - .iter() - .map(|c| c.rows_unawaited) - .max() - .unwrap() - } - - pub fn wait_u64(&mut self, num_rows: u64) -> BoxFuture> { - async move { - for child in self.children.iter_mut() { - child.wait(num_rows).await?; - } - Ok(()) - } - .boxed() - } - - pub fn drain_u64(&mut self, num_rows: u64) -> Result { - let child_tasks = self - .children - .iter_mut() - .map(|child| child.drain(num_rows)) - .collect::>>()?; - let num_rows = child_tasks[0].num_rows; - let has_more = child_tasks[0].has_more; - debug_assert!(child_tasks.iter().all(|task| task.num_rows == num_rows)); - debug_assert!(child_tasks.iter().all(|task| task.has_more == has_more)); - Ok(NextDecodeTask { - task: Box::new(SimpleStructDecodeTask { - children: child_tasks, - child_fields: self.child_fields.clone(), - }), - num_rows, - has_more, - }) - } } impl LogicalPageDecoder for SimpleStructDecoder { @@ -423,21 +374,21 @@ impl LogicalPageDecoder for SimpleStructDecoder { Ok(()) } - fn wait(&mut self, num_rows: u32) -> BoxFuture> { + fn wait(&mut self, num_rows: u64) -> BoxFuture> { async move { for child in 
self.children.iter_mut() { - child.wait(num_rows as u64).await?; + child.wait(num_rows).await?; } Ok(()) } .boxed() } - fn drain(&mut self, num_rows: u32) -> Result { + fn drain(&mut self, num_rows: u64) -> Result { let child_tasks = self .children .iter_mut() - .map(|child| child.drain(num_rows as u64)) + .map(|child| child.drain(num_rows)) .collect::>>()?; let num_rows = child_tasks[0].num_rows; let has_more = child_tasks[0].has_more; @@ -454,17 +405,21 @@ impl LogicalPageDecoder for SimpleStructDecoder { } // Rows are available only if they are available in every child column - fn avail(&self) -> u32 { - let avail = self.avail_u64(); - debug_assert!(avail <= u32::MAX as u64); - avail as u32 + fn avail(&self) -> u64 { + self.children + .iter() + .map(|c| c.rows_available) + .min() + .unwrap() } // Rows are unawaited if they are unawaited in any child column - fn unawaited(&self) -> u32 { - let unawaited = self.unawaited_u64(); - debug_assert!(unawaited <= u32::MAX as u64); - unawaited as u32 + fn unawaited(&self) -> u64 { + self.children + .iter() + .map(|c| c.rows_unawaited) + .max() + .unwrap() } fn data_type(&self) -> &DataType { @@ -495,7 +450,7 @@ impl DecodeArrayTask for SimpleStructDecodeTask { pub struct StructFieldEncoder { children: Vec>, column_index: u32, - num_rows_seen: u32, + num_rows_seen: u64, } impl StructFieldEncoder { @@ -511,7 +466,7 @@ impl StructFieldEncoder { impl FieldEncoder for StructFieldEncoder { fn maybe_encode(&mut self, array: ArrayRef) -> Result> { - self.num_rows_seen += array.len() as u32; + self.num_rows_seen += array.len() as u64; let struct_array = array.as_struct(); let child_tasks = self .children diff --git a/rust/lance-encoding/src/encodings/physical.rs b/rust/lance-encoding/src/encodings/physical.rs index 0c963d4ae5..e3895daefe 100644 --- a/rust/lance-encoding/src/encodings/physical.rs +++ b/rust/lance-encoding/src/encodings/physical.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use arrow_schema::DataType; + use crate::encodings::physical::value::CompressionScheme; use crate::{decoder::PageScheduler, format::pb}; @@ -79,19 +81,30 @@ fn get_buffer_decoder(encoding: &pb::Flat, buffers: &PageBuffers) -> Box Box { match encoding.array_encoding.as_ref().unwrap() { pb::array_encoding::ArrayEncoding::Nullable(basic) => { match basic.nullability.as_ref().unwrap() { - pb::nullable::Nullability::NoNulls(no_nulls) => { - Box::new(BasicPageScheduler::new_non_nullable( - decoder_from_array_encoding(no_nulls.values.as_ref().unwrap(), buffers), - )) - } + pb::nullable::Nullability::NoNulls(no_nulls) => Box::new( + BasicPageScheduler::new_non_nullable(decoder_from_array_encoding( + no_nulls.values.as_ref().unwrap(), + buffers, + data_type, + )), + ), pb::nullable::Nullability::SomeNulls(some_nulls) => { Box::new(BasicPageScheduler::new_nullable( - decoder_from_array_encoding(some_nulls.validity.as_ref().unwrap(), buffers), - decoder_from_array_encoding(some_nulls.values.as_ref().unwrap(), buffers), + decoder_from_array_encoding( + some_nulls.validity.as_ref().unwrap(), + buffers, + data_type, + ), + decoder_from_array_encoding( + some_nulls.values.as_ref().unwrap(), + buffers, + data_type, + ), )) } pb::nullable::Nullability::AllNulls(_) => { @@ -102,7 +115,7 @@ pub fn decoder_from_array_encoding( pb::array_encoding::ArrayEncoding::Flat(flat) => get_buffer_decoder(flat, buffers), pb::array_encoding::ArrayEncoding::FixedSizeList(fixed_size_list) => { let item_encoding = 
fixed_size_list.items.as_ref().unwrap(); - let item_scheduler = decoder_from_array_encoding(item_encoding, buffers); + let item_scheduler = decoder_from_array_encoding(item_encoding, buffers, data_type); Box::new(FixedListScheduler::new( item_scheduler, fixed_size_list.dimension, @@ -112,18 +125,26 @@ pub fn decoder_from_array_encoding( // since we know it is a list based on the schema. In the future there may be different ways // of storing the list offsets. pb::array_encoding::ArrayEncoding::List(list) => { - decoder_from_array_encoding(list.offsets.as_ref().unwrap(), buffers) + decoder_from_array_encoding(list.offsets.as_ref().unwrap(), buffers, data_type) } pb::array_encoding::ArrayEncoding::Binary(binary) => { let indices_encoding = binary.indices.as_ref().unwrap(); let bytes_encoding = binary.bytes.as_ref().unwrap(); - let indices_scheduler = decoder_from_array_encoding(indices_encoding, buffers); - let bytes_scheduler = decoder_from_array_encoding(bytes_encoding, buffers); + let indices_scheduler = + decoder_from_array_encoding(indices_encoding, buffers, data_type); + let bytes_scheduler = decoder_from_array_encoding(bytes_encoding, buffers, data_type); + + let offset_type = match data_type { + DataType::LargeBinary | DataType::LargeUtf8 => DataType::Int64, + _ => DataType::Int32, + }; Box::new(BinaryPageScheduler::new( indices_scheduler.into(), bytes_scheduler.into(), + offset_type, + binary.null_adjustment, )) } // Currently there is no way to encode struct nullability and structs are encoded with a "header" column diff --git a/rust/lance-encoding/src/encodings/physical/basic.rs b/rust/lance-encoding/src/encodings/physical/basic.rs index 87c466eada..3dccc2cf82 100644 --- a/rust/lance-encoding/src/encodings/physical/basic.rs +++ b/rust/lance-encoding/src/encodings/physical/basic.rs @@ -122,7 +122,7 @@ impl BasicPageScheduler { impl PageScheduler for BasicPageScheduler { fn schedule_ranges( &self, - ranges: &[std::ops::Range], + ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, ) -> BoxFuture<'static, Result>> { @@ -171,8 +171,8 @@ struct BasicPageDecoder { impl PrimitivePageDecoder for BasicPageDecoder { fn decode( &self, - rows_to_skip: u32, - num_rows: u32, + rows_to_skip: u64, + num_rows: u64, all_null: &mut bool, ) -> Result> { let dest_buffers = match &self.mode { diff --git a/rust/lance-encoding/src/encodings/physical/binary.rs b/rust/lance-encoding/src/encodings/physical/binary.rs index 6d2c47d704..f593b35386 100644 --- a/rust/lance-encoding/src/encodings/physical/binary.rs +++ b/rust/lance-encoding/src/encodings/physical/binary.rs @@ -5,16 +5,12 @@ use core::panic; use std::sync::Arc; use arrow_array::cast::AsArray; -use arrow_array::types::UInt32Type; -// use arrow::compute::concat; -use arrow_array::{ - builder::{ArrayBuilder, Int32Builder, UInt32Builder, UInt8Builder}, - Array, ArrayRef, Int32Array, UInt32Array, -}; +use arrow_array::types::UInt64Type; +use arrow_array::{Array, ArrayRef}; +use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, ScalarBuffer}; use bytes::BytesMut; use futures::stream::StreamExt; use futures::{future::BoxFuture, stream::FuturesOrdered, FutureExt}; -// use rand::seq::index; use crate::{ decoder::{PageScheduler, PrimitivePageDecoder}, @@ -26,33 +22,97 @@ use crate::{ use crate::decoder::LogicalPageDecoder; use crate::encodings::logical::primitive::PrimitiveFieldDecoder; -use arrow_array::PrimitiveArray; +use arrow_array::{PrimitiveArray, UInt64Array, UInt8Array}; use arrow_schema::DataType; use lance_core::Result; -use 
std::ops::Deref; + +struct IndicesNormalizer { + indices: Vec, + validity: BooleanBufferBuilder, + null_adjustment: u64, +} + +impl IndicesNormalizer { + fn new(num_rows: u64, null_adjustment: u64) -> Self { + let mut indices = Vec::with_capacity(num_rows as usize); + indices.push(0); + Self { + indices, + validity: BooleanBufferBuilder::new(num_rows as usize), + null_adjustment, + } + } + + fn normalize(&self, val: u64) -> (bool, u64) { + if val >= self.null_adjustment { + (false, val - self.null_adjustment) + } else { + (true, val) + } + } + + fn extend(&mut self, new_indices: &PrimitiveArray, is_start: bool) { + let mut last = *self.indices.last().unwrap(); + if is_start { + let (is_valid, val) = self.normalize(new_indices.value(0)); + self.indices.push(val); + self.validity.append(is_valid); + last += val; + } + let mut prev = self.normalize(*new_indices.values().first().unwrap()).1; + for w in new_indices.values().windows(2) { + let (is_valid, val) = self.normalize(w[1]); + let next = val - prev + last; + self.indices.push(next); + self.validity.append(is_valid); + prev = val; + last = next; + } + } + + fn into_parts(mut self) -> (Vec, BooleanBuffer) { + (self.indices, self.validity.finish()) + } +} #[derive(Debug)] pub struct BinaryPageScheduler { indices_scheduler: Arc, bytes_scheduler: Arc, + offsets_type: DataType, + null_adjustment: u64, } impl BinaryPageScheduler { pub fn new( indices_scheduler: Arc, bytes_scheduler: Arc, + offsets_type: DataType, + null_adjustment: u64, ) -> Self { Self { indices_scheduler, bytes_scheduler, + offsets_type, + null_adjustment, } } } +impl BinaryPageScheduler { + fn decode_indices(decoder: Arc, num_rows: u64) -> Result { + let mut primitive_wrapper = + PrimitiveFieldDecoder::new_from_data(decoder, DataType::UInt64, num_rows); + let drained_task = primitive_wrapper.drain(num_rows)?; + let indices_decode_task = drained_task.task; + indices_decode_task.decode() + } +} + impl PageScheduler for BinaryPageScheduler { fn schedule_ranges( &self, - ranges: &[std::ops::Range], + ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, ) -> BoxFuture<'static, Result>> { @@ -69,22 +129,25 @@ impl PageScheduler for BinaryPageScheduler { 0..(range.end) } }) - .collect::>>(); + .collect::>>(); + + let num_rows = ranges.iter().map(|r| r.end - r.start).sum::(); - let mut futures_ordered = FuturesOrdered::new(); - for range in indices_ranges.iter() { - let indices_page_decoder = + let mut futures_ordered = indices_ranges + .iter() + .map(|range| { self.indices_scheduler - .schedule_ranges(&[range.clone()], scheduler, top_level_row); - futures_ordered.push_back(indices_page_decoder); - } + .schedule_ranges(&[range.clone()], scheduler, top_level_row) + }) + .collect::>(); let ranges = ranges.to_vec(); let copy_scheduler = scheduler.clone(); let copy_bytes_scheduler = self.bytes_scheduler.clone(); - let copy_indices_ranges = indices_ranges.to_vec(); + let null_adjustment = self.null_adjustment; + let offsets_type = self.offsets_type.clone(); - async move { + tokio::spawn(async move { // For the following data: // "abcd", "hello", "abcd", "apple", "hello", "abcd" // 4, 9, 13, 18, 23, 27 @@ -97,87 +160,63 @@ impl PageScheduler for BinaryPageScheduler { // These are the normalized offsets stored in decoded_indices // Rest of the workflow is continued later in BinaryPageDecoder - let mut builder = UInt32Builder::new(); + let mut indices_builder = IndicesNormalizer::new(num_rows, null_adjustment); let mut bytes_ranges = Vec::new(); let mut curr_range_idx = 0; - let 
mut last = 0; while let Some(indices_page_decoder) = futures_ordered.next().await { - let indices: Arc = Arc::from(indices_page_decoder?); + let decoder = Arc::from(indices_page_decoder?); // Build and run decode task for offsets - let curr_indices_range = copy_indices_ranges[curr_range_idx].clone(); + let curr_indices_range = indices_ranges[curr_range_idx].clone(); let curr_row_range = ranges[curr_range_idx].clone(); let indices_num_rows = curr_indices_range.end - curr_indices_range.start; - let mut primitive_wrapper = PrimitiveFieldDecoder::new_from_data( - indices, - DataType::UInt32, - indices_num_rows, - ); - let drained_task = primitive_wrapper.drain(indices_num_rows)?; - let indices_decode_task = drained_task.task; - let decoded_part = indices_decode_task.decode()?; - - let indices_array = decoded_part.as_primitive::(); - let mut indices_vec = indices_array.values().to_vec(); - - // Pad a zero at the start if the first row is requested - // This is because the offsets do not start with 0 by default - if curr_row_range.start == 0 { - indices_vec.insert(0, 0); - } - // Normalize the indices as described above - let normalized_indices: PrimitiveArray = indices_vec - .iter() - .map(|x| x - indices_vec[0] + last) - .collect(); - last = normalized_indices.value(normalized_indices.len() - 1); - let normalized_vec = normalized_indices.values().to_vec(); - - // The first vector to be normalized should not have the leading zero removed - let truncated_vec = if curr_range_idx == 0 { - normalized_vec.as_slice() - } else { - &normalized_vec[1..] - }; + let indices = Self::decode_indices(decoder, indices_num_rows)?; + let indices = indices.as_primitive::(); - builder.append_slice(truncated_vec); - - // get bytes range from the index range - let bytes_range = if curr_row_range.start != 0 { - indices_array.value(0)..indices_array.value(indices_array.len() - 1) + let first = if curr_row_range.start == 0 { + 0 } else { - 0..indices_array.value(indices_array.len() - 1) + indices_builder + .normalize(*indices.values().first().unwrap()) + .1 }; + let last = indices_builder + .normalize(*indices.values().last().unwrap()) + .1; + bytes_ranges.push(first..last); + + indices_builder.extend(indices, curr_row_range.start == 0); - bytes_ranges.push(bytes_range); curr_range_idx += 1; } - let decoded_indices = Arc::new(builder.finish()); - - let bytes_ranges_slice = bytes_ranges.as_slice(); + let (indices, validity) = indices_builder.into_parts(); + let decoded_indices = UInt64Array::from(indices); // Schedule the bytes for decoding - let bytes_page_decoder = copy_bytes_scheduler.schedule_ranges( - bytes_ranges_slice, - ©_scheduler, - top_level_row, - ); + let bytes_page_decoder = + copy_bytes_scheduler.schedule_ranges(&bytes_ranges, ©_scheduler, top_level_row); let bytes_decoder: Box = bytes_page_decoder.await?; Ok(Box::new(BinaryPageDecoder { decoded_indices, + validity, + offsets_type, bytes_decoder, }) as Box) - } + }) + // Propagate join panic + .map(|join_handle| join_handle.unwrap()) .boxed() } } struct BinaryPageDecoder { - decoded_indices: Arc, + decoded_indices: UInt64Array, + offsets_type: DataType, + validity: BooleanBuffer, bytes_decoder: Box, } @@ -189,48 +228,95 @@ impl PrimitivePageDecoder for BinaryPageDecoder { // Allocate 8 bytes capacity. // Next rows_to_skip=2, num_rows=1 // Skip 8 bytes. Allocate 5 bytes capacity. + // // The normalized offsets are [0, 4, 8, 13] // We only need [8, 13] to decode in this case. 
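    // (Concretely, with rows_to_skip=2 and num_rows=1, the code below computes
    // bytes_to_skip = offsets[2] = 8 and num_bytes = offsets[3] - offsets[2] = 5.)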
// These need to be normalized in order to build the string later // So return [0, 5] fn decode( &self, - rows_to_skip: u32, - num_rows: u32, + rows_to_skip: u64, + num_rows: u64, all_null: &mut bool, ) -> Result> { - let offsets = self + // Buffers[0] == validity buffer + // Buffers[1] == offsets buffer + // Buffers[2] == null buffer // TODO: Micro-optimization, can we get rid of this? Doesn't hurt much though + // This buffer is always empty since bytes are not allowed to contain nulls + // Buffers[3] == bytes buffer + + // STEP 1: validity buffer + let target_validity = self + .validity + .slice(rows_to_skip as usize, num_rows as usize); + let has_nulls = target_validity.count_set_bits() < target_validity.len(); + + let validity_buffer = if has_nulls { + let num_validity_bits = arrow_buffer::bit_util::ceil(num_rows as usize, 8); + let mut validity_buffer = BytesMut::with_capacity(num_validity_bits); + + if rows_to_skip == 0 { + validity_buffer.extend_from_slice(target_validity.inner().as_slice()); + } else { + // Need to copy the buffer because there may be a bit offset in first byte + let target_validity = BooleanBuffer::from_iter(target_validity.iter()); + validity_buffer.extend_from_slice(target_validity.inner().as_slice()); + } + validity_buffer + } else { + BytesMut::new() + }; + + // STEP 2: offsets buffer + // Currently we always do a copy here, we need to cast to the appropriate type + // and we go ahead and normalize so the starting offset is 0 (though we could skip + // this) + let bytes_per_offset = match self.offsets_type { + DataType::Int32 => 4, + DataType::Int64 => 8, + _ => panic!("Unsupported offsets type"), + }; + + let target_offsets = self .decoded_indices - .as_any() - .downcast_ref::() - .unwrap(); - - let bytes_to_skip = offsets.value(rows_to_skip as usize); - let num_bytes = offsets.value((rows_to_skip + num_rows) as usize) - bytes_to_skip; - let target_offsets = offsets.slice( - rows_to_skip.try_into().unwrap(), - (num_rows + 1).try_into().unwrap(), - ); + .slice(rows_to_skip as usize, (num_rows + 1) as usize); - let mut bytes_buffers = self - .bytes_decoder - .decode(bytes_to_skip, num_bytes, all_null)?; - - // Normalize offsets + // Normalize and cast (TODO: could fuse these into one pass for micro-optimization) let target_vec = target_offsets.values(); - let normalized_array: PrimitiveArray = - target_vec.iter().map(|x| x - target_vec[0]).collect(); - let normalized_values = normalized_array.values(); - - let byte_slice = normalized_values.inner().deref(); - let mut dest_buffers = vec![BytesMut::from(byte_slice)]; - dest_buffers.append(&mut bytes_buffers); + let start = target_vec[0]; + let offsets_buffer = + match bytes_per_offset { + 4 => ScalarBuffer::from_iter(target_vec.iter().map(|x| (x - start) as i32)) + .into_inner(), + 8 => ScalarBuffer::from_iter(target_vec.iter().map(|x| (x - start) as i64)) + .into_inner(), + _ => panic!("Unsupported offsets type"), + }; + // TODO: This forces a second copy, which is unfortunate, try and remove in the future + let offsets_buf = BytesMut::from(offsets_buffer.as_slice()); + + let bytes_to_skip = self.decoded_indices.value(rows_to_skip as usize); + let num_bytes = self + .decoded_indices + .value((rows_to_skip + num_rows) as usize) + - bytes_to_skip; + + let mut output_buffers = vec![validity_buffer, offsets_buf]; + + // Copy decoded bytes into dest_buffers[2..] 
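+        // (These are the buffers returned by the wrapped bytes decoder, which is
+        // why num_buffers() below reports self.bytes_decoder.num_buffers() + 2,
+        // the extra two being validity and offsets.)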
+ // Currently an empty null buffer is the first one + // The actual bytes are in the second buffer + // Including the indices this results in 4 buffers in total + output_buffers.extend( + self.bytes_decoder + .decode(bytes_to_skip, num_bytes, all_null)?, + ); - Ok(dest_buffers) + Ok(output_buffers) } fn num_buffers(&self) -> u32 { - self.bytes_decoder.num_buffers() + 1 + self.bytes_decoder.num_buffers() + 2 } } @@ -256,149 +342,302 @@ impl BinaryEncoder { // Strings are a vector of arrays corresponding to each record batch // Zero offset is removed from the start of the offsets array // The indices array is computed across all arrays in the vector -fn get_indices_from_string_arrays(arrays: &[ArrayRef]) -> ArrayRef { - let mut indices_builder = Int32Builder::new(); - let mut last_offset = 0; - arrays.iter().for_each(|arr| { - let string_arr = arrow_array::cast::as_string_array(arr); - let offsets = string_arr.offsets().inner(); - let mut offsets = offsets.slice(1, offsets.len() - 1).to_vec(); - - if indices_builder.len() == 0 { - last_offset = offsets[offsets.len() - 1]; +fn get_indices_from_string_arrays(arrays: &[ArrayRef]) -> (ArrayRef, u64) { + let num_rows = arrays.iter().map(|arr| arr.len()).sum::(); + let mut indices = Vec::with_capacity(num_rows); + let mut last_offset = 0_u64; + for array in arrays { + if let Some(array) = array.as_string_opt::() { + let offsets = array.offsets().inner(); + indices.extend(offsets.windows(2).map(|w| { + let strlen = (w[1] - w[0]) as u64; + let off = strlen + last_offset; + last_offset = off; + off + })); + } else if let Some(array) = array.as_string_opt::() { + let offsets = array.offsets().inner(); + indices.extend(offsets.windows(2).map(|w| { + let strlen = (w[1] - w[0]) as u64; + let off = strlen + last_offset; + last_offset = off; + off + })); + } else if let Some(array) = array.as_binary_opt::() { + let offsets = array.offsets().inner(); + indices.extend(offsets.windows(2).map(|w| { + let strlen = (w[1] - w[0]) as u64; + let off = strlen + last_offset; + last_offset = off; + off + })); + } else if let Some(array) = array.as_binary_opt::() { + let offsets = array.offsets().inner(); + indices.extend(offsets.windows(2).map(|w| { + let strlen = (w[1] - w[0]) as u64; + let off = strlen + last_offset; + last_offset = off; + off + })); } else { - offsets = offsets - .iter() - .map(|offset| offset + last_offset) - .collect::>(); - last_offset = offsets[offsets.len() - 1]; + panic!("Array is not a string array"); } + } + let last_offset = *indices.last().expect("Indices array is empty"); + // 8 exabytes in a single array seems unlikely but...just in case + assert!( + last_offset < u64::MAX / 2, + "Indices array with strings up to 2^63 is too large for this encoding" + ); + let null_adjustment: u64 = *indices.last().expect("Indices array is empty") + 1; + + let mut indices_offset = 0; + for array in arrays { + if let Some(nulls) = array.nulls() { + let indices_slice = &mut indices[indices_offset..indices_offset + array.len()]; + indices_slice + .iter_mut() + .zip(nulls.iter()) + .for_each(|(index, is_valid)| { + if !is_valid { + *index += null_adjustment; + } + }); + } + indices_offset += array.len(); + } - let new_int_arr = Int32Array::from(offsets); - indices_builder.append_slice(new_int_arr.values()); - }); - - Arc::new(indices_builder.finish()) as ArrayRef + (Arc::new(UInt64Array::from(indices)), null_adjustment) } // Bytes computed across all string arrays, similar to indices above -fn get_bytes_from_string_arrays(arrays: &[ArrayRef]) -> ArrayRef { 
- let mut bytes_builder = UInt8Builder::new(); - arrays.iter().for_each(|arr| { - let string_arr = arrow_array::cast::as_string_array(arr); - let values = string_arr.values(); - bytes_builder.append_slice(values); - }); - - Arc::new(bytes_builder.finish()) as ArrayRef +fn get_bytes_from_string_arrays(arrays: &[ArrayRef]) -> Vec { + arrays + .iter() + .map(|arr| { + let (values_buffer, start, stop) = if let Some(arr) = arr.as_string_opt::() { + ( + arr.values(), + arr.offsets()[0] as usize, + arr.offsets()[arr.len()] as usize, + ) + } else if let Some(arr) = arr.as_string_opt::() { + ( + arr.values(), + arr.offsets()[0] as usize, + arr.offsets()[arr.len()] as usize, + ) + } else if let Some(arr) = arr.as_binary_opt::() { + ( + arr.values(), + arr.offsets()[0] as usize, + arr.offsets()[arr.len()] as usize, + ) + } else if let Some(arr) = arr.as_binary_opt::() { + ( + arr.values(), + arr.offsets()[0] as usize, + arr.offsets()[arr.len()] as usize, + ) + } else { + panic!("Array is not a string / binary array"); + }; + let values = ScalarBuffer::new(values_buffer.clone(), start, stop - start); + Arc::new(UInt8Array::new(values, None)) as ArrayRef + }) + .collect() } impl ArrayEncoder for BinaryEncoder { fn encode(&self, arrays: &[ArrayRef], buffer_index: &mut u32) -> Result { - let (null_count, _row_count) = arrays - .iter() - .map(|arr| (arr.null_count() as u32, arr.len() as u32)) - .fold((0, 0), |acc, val| (acc.0 + val.0, acc.1 + val.1)); + let (index_array, null_adjustment) = get_indices_from_string_arrays(arrays); + let encoded_indices = self.indices_encoder.encode(&[index_array], buffer_index)?; + + let byte_arrays = get_bytes_from_string_arrays(arrays); + let encoded_bytes = self.bytes_encoder.encode(&byte_arrays, buffer_index)?; + + let mut encoded_buffers = encoded_indices.buffers; + encoded_buffers.extend(encoded_bytes.buffers); + + Ok(EncodedArray { + buffers: encoded_buffers, + encoding: pb::ArrayEncoding { + array_encoding: Some(pb::array_encoding::ArrayEncoding::Binary(Box::new( + pb::Binary { + indices: Some(Box::new(encoded_indices.encoding)), + bytes: Some(Box::new(encoded_bytes.encoding)), + null_adjustment, + }, + ))), + }, + }) + } +} - if null_count != 0 { - panic!("Data contains null values, not currently supported for binary data.") - } else { - let index_array = get_indices_from_string_arrays(arrays); - let encoded_indices = self.indices_encoder.encode(&[index_array], buffer_index)?; - - let byte_array = get_bytes_from_string_arrays(arrays); - let encoded_bytes = self.bytes_encoder.encode(&[byte_array], buffer_index)?; - - let mut encoded_buffers = encoded_indices.buffers; - encoded_buffers.extend(encoded_bytes.buffers); - - Ok(EncodedArray { - buffers: encoded_buffers, - encoding: pb::ArrayEncoding { - array_encoding: Some(pb::array_encoding::ArrayEncoding::Binary(Box::new( - pb::Binary { - indices: Some(Box::new(encoded_indices.encoding)), - bytes: Some(Box::new(encoded_bytes.encoding)), - }, - ))), - }, - }) +#[cfg(test)] +pub mod tests { + + use arrow_array::{ + builder::{LargeStringBuilder, StringBuilder}, + ArrayRef, LargeStringArray, StringArray, UInt64Array, + }; + use arrow_schema::{DataType, Field}; + use std::{sync::Arc, vec}; + + use crate::testing::{ + check_round_trip_encoding_of_data, check_round_trip_encoding_random, TestCases, + }; + + use super::get_indices_from_string_arrays; + + #[test_log::test(tokio::test)] + async fn test_utf8() { + let field = Field::new("", DataType::Utf8, false); + check_round_trip_encoding_random(field).await; + } + + #[test] + 
fn test_encode_indices_stitches_offsets() {
+        // Given two string arrays we might have offsets [5, 10, 15] and [0, 3, 7]
+        //
+        // We need to stitch them to [0, 5, 10, 13, 17]
+        let string_array1 = StringArray::from(vec![Some("abcde"), Some("abcde"), Some("abcde")]);
+        let string_array1 = Arc::new(string_array1.slice(1, 2));
+        let string_array2 = Arc::new(StringArray::from(vec![Some("abc"), Some("abcd")]));
+        let (offsets, null_adjustment) =
+            get_indices_from_string_arrays(&[string_array1, string_array2]);
+
+        let expected = Arc::new(UInt64Array::from(vec![5, 10, 13, 17])) as ArrayRef;
+        assert_eq!(&offsets, &expected);
+        assert_eq!(null_adjustment, 18);
+    }
+
+    #[test]
+    fn test_encode_indices_adjusts_nulls() {
+        // Null entries in string arrays should be adjusted
+        let string_array1 = Arc::new(StringArray::from(vec![None, Some("foo")]));
+        let string_array2 = Arc::new(StringArray::from(vec![Some("foo"), None]));
+        let string_array3 = Arc::new(StringArray::from(vec![None as Option<&str>, None]));
+        let (offsets, null_adjustment) =
+            get_indices_from_string_arrays(&[string_array1, string_array2, string_array3]);
+
+        let expected = Arc::new(UInt64Array::from(vec![7, 3, 6, 13, 13, 13])) as ArrayRef;
+        assert_eq!(&offsets, &expected);
+        assert_eq!(null_adjustment, 7);
+    }
+
+    #[test]
+    fn test_encode_indices_string_types() {
+        // Each of the four supported string / binary types should produce the
+        // same indices for the same payload
+        let string_array = Arc::new(StringArray::from(vec![Some("foo")])) as ArrayRef;
+        let large_string_array = Arc::new(LargeStringArray::from(vec![Some("foo")])) as ArrayRef;
+        let binary_array =
+            Arc::new(arrow_array::BinaryArray::from(vec![Some(b"foo" as &[u8])])) as ArrayRef;
+        let large_binary_array =
+            Arc::new(arrow_array::LargeBinaryArray::from(vec![Some(b"foo" as &[u8])])) as ArrayRef;
+
+        for arr in [
+            string_array,
+            large_string_array,
+            binary_array,
+            large_binary_array,
+        ] {
+            let (offsets, null_adjustment) = get_indices_from_string_arrays(&[arr]);
+            let expected = Arc::new(UInt64Array::from(vec![3])) as ArrayRef;
+            assert_eq!(&offsets, &expected);
+            assert_eq!(null_adjustment, 4);
+        }
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_binary() {
+        let field = Field::new("", DataType::Binary, false);
+        check_round_trip_encoding_random(field).await;
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_large_binary() {
+        let field = Field::new("", DataType::LargeBinary, true);
+        check_round_trip_encoding_random(field).await;
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_large_utf8() {
+        let field = Field::new("", DataType::LargeUtf8, true);
+        check_round_trip_encoding_random(field).await;
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_simple_utf8() {
+        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
+
+        let test_cases = TestCases::default()
+            .with_range(0..2)
+            .with_range(0..3)
+            .with_range(1..3)
+            .with_indices(vec![1, 3]);
+        check_round_trip_encoding_of_data(vec![Arc::new(string_array)], &test_cases).await;
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_sliced_utf8() {
+        let string_array = StringArray::from(vec![Some("abc"), Some("de"), None, Some("fgh")]);
+        let string_array = string_array.slice(1, 3);
+
+        let test_cases = TestCases::default()
+            .with_range(0..1)
+            .with_range(0..2)
+            .with_range(1..2);
+        check_round_trip_encoding_of_data(vec![Arc::new(string_array)], &test_cases).await;
+    }
+
+    #[test_log::test(tokio::test)]
+    async fn test_empty_strings() {
+        // Scenario 1: Some strings are empty
+
+        let values = [Some("abc"), Some(""), None];
+        // Test the empty string at the beginning, middle, and end
+        for order in [[0, 1, 2], [1, 0, 2], [2, 0, 1]] {
+            let mut
string_builder = StringBuilder::new(); + for idx in order { + string_builder.append_option(values[idx]); + } + let string_array = Arc::new(string_builder.finish()); + let test_cases = TestCases::default() + .with_indices(vec![1]) + .with_indices(vec![0]) + .with_indices(vec![2]); + check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases).await; + let test_cases = test_cases.with_batch_size(1); + check_round_trip_encoding_of_data(vec![string_array], &test_cases).await; + } + + // Scenario 2: All strings are empty + + // When encoding an array of empty strings there are no bytes to encode + // which is strange and we want to ensure we handle it + let string_array = Arc::new(StringArray::from(vec![Some(""), None, Some("")])); + + let test_cases = TestCases::default().with_range(0..2).with_indices(vec![1]); + check_round_trip_encoding_of_data(vec![string_array.clone()], &test_cases).await; + let test_cases = test_cases.with_batch_size(1); + check_round_trip_encoding_of_data(vec![string_array], &test_cases).await; + } + + #[test_log::test(tokio::test)] + #[ignore] // This test is quite slow in debug mode + async fn test_jumbo_string() { + // This is an overflow test. We have a list of lists where each list + // has 1Mi items. We encode 5000 of these lists and so we have over 4Gi in the + // offsets range + let mut string_builder = LargeStringBuilder::new(); + // a 1 MiB string + let giant_string = String::from_iter((0..(1024 * 1024)).map(|_| '0')); + for _ in 0..5000 { + string_builder.append_option(Some(&giant_string)); } - // Currently not handling all null cases in this array encoder. - // TODO: Separate behavior for all null rows vs some null rows - // else if null_count == row_count { - // let nullability = pb::nullable::Nullability::AllNulls(pb::nullable::AllNull {}); - - // Ok(EncodedArray { - // buffers: vec![], - // encoding: pb::ArrayEncoding { - // array_encoding: Some(pb::array_encoding::ArrayEncoding::Nullable(Box::new( - // pb::Nullable { - // nullability: Some(nullability), - // }, - // ))), - // }, - // }) - // } - - // let arr_encoding = self.values_encoder.encode(arrays, buffer_index)?; - // let encoding = pb::nullable::Nullability::NoNulls(Box::new(pb::nullable::NoNull { - // values: Some(Box::new(arr_encoding.encoding)), - // })); - // (arr_encoding.buffers, encoding) - // } else if null_count == row_count { - // let encoding = pb::nullable::Nullability::AllNulls(pb::nullable::AllNull {}); - // (vec![], encoding) - // } else { - // let validity_as_arrays = arrays - // .iter() - // .map(|arr| { - // if let Some(nulls) = arr.nulls() { - // Arc::new(BooleanArray::new(nulls.inner().clone(), None)) as ArrayRef - // } else { - // let buff = BooleanBuffer::new_set(arr.len()); - // Arc::new(BooleanArray::new(buff, None)) as ArrayRef - // } - // }) - // .collect::>(); - - // let validity_buffer_index = *buffer_index; - // *buffer_index += 1; - // let validity = BitmapBufferEncoder::default().encode(&validity_as_arrays)?; - // let validity_encoding = Box::new(pb::ArrayEncoding { - // array_encoding: Some(pb::array_encoding::ArrayEncoding::Flat(pb::Flat { - // bits_per_value: 1, - // buffer: Some(pb::Buffer { - // buffer_index: validity_buffer_index, - // buffer_type: pb::buffer::BufferType::Page as i32, - // }), - // compression: None, - // })), - // }); - - // let arr_encoding = self.values_encoder.encode(arrays, buffer_index)?; - // let encoding = pb::nullable::Nullability::SomeNulls(Box::new(pb::nullable::SomeNull { - // validity: Some(validity_encoding), - // 
values: Some(Box::new(arr_encoding.encoding)), - // })); - - // let mut buffers = arr_encoding.buffers; - // buffers.push(EncodedArrayBuffer { - // parts: validity.parts, - // index: validity_buffer_index, - // }); - // (buffers, encoding) - // }; - - // Ok(EncodedArray { - // buffers, - // encoding: pb::ArrayEncoding { - // array_encoding: Some(pb::array_encoding::ArrayEncoding::Nullable(Box::new( - // pb::Nullable { - // nullability: Some(nullability), - // }, - // ))), - // }, - // }) + let giant_array = Arc::new(string_builder.finish()) as ArrayRef; + let arrs = vec![giant_array]; + + // // We can't validate because our validation relies on concatenating all input arrays + let test_cases = TestCases::default().without_validation(); + check_round_trip_encoding_of_data(arrs, &test_cases).await; } } diff --git a/rust/lance-encoding/src/encodings/physical/bitmap.rs b/rust/lance-encoding/src/encodings/physical/bitmap.rs index 12fe355d88..eaf9dffb84 100644 --- a/rust/lance-encoding/src/encodings/physical/bitmap.rs +++ b/rust/lance-encoding/src/encodings/physical/bitmap.rs @@ -33,7 +33,7 @@ impl DenseBitmapScheduler { impl PageScheduler for DenseBitmapScheduler { fn schedule_ranges( &self, - ranges: &[Range], + ranges: &[Range], scheduler: &Arc, top_level_row: u64, ) -> BoxFuture<'static, Result>> { @@ -43,9 +43,9 @@ impl PageScheduler for DenseBitmapScheduler { .iter() .map(|range| { debug_assert_ne!(range.start, range.end); - let start = self.buffer_offset + range.start as u64 / 8; + let start = self.buffer_offset + range.start / 8; let bit_offset = range.start % 8; - let end = self.buffer_offset + range.end.div_ceil(8) as u64; + let end = self.buffer_offset + range.end.div_ceil(8); let byte_range = start..end; min = min.min(start); max = max.max(end); @@ -84,8 +84,8 @@ impl PageScheduler for DenseBitmapScheduler { struct BitmapData { data: Bytes, - bit_offset: u32, - length: u32, + bit_offset: u64, + length: u64, } struct BitmapDecoder { @@ -95,8 +95,8 @@ struct BitmapDecoder { impl PrimitivePageDecoder for BitmapDecoder { fn decode( &self, - rows_to_skip: u32, - num_rows: u32, + rows_to_skip: u64, + num_rows: u64, _all_null: &mut bool, ) -> Result> { let num_bytes = arrow_buffer::bit_util::ceil(num_rows as usize, 8); diff --git a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs index ce6a591226..b9308b0af2 100644 --- a/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs +++ b/rust/lance-encoding/src/encodings/physical/fixed_size_list.rs @@ -37,13 +37,13 @@ impl FixedListScheduler { impl PageScheduler for FixedListScheduler { fn schedule_ranges( &self, - ranges: &[std::ops::Range], + ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, ) -> BoxFuture<'static, Result>> { let expanded_ranges = ranges .iter() - .map(|range| (range.start * self.dimension)..(range.end * self.dimension)) + .map(|range| (range.start * self.dimension as u64)..(range.end * self.dimension as u64)) .collect::>(); trace!( "Expanding {} fsl ranges across {}..{} to item ranges across {}..{}", @@ -61,7 +61,7 @@ impl PageScheduler for FixedListScheduler { let items_decoder = inner_page_decoder.await?; Ok(Box::new(FixedListDecoder { items_decoder, - dimension, + dimension: dimension as u64, }) as Box) } .boxed() @@ -70,14 +70,14 @@ impl PageScheduler for FixedListScheduler { pub struct FixedListDecoder { items_decoder: Box, - dimension: u32, + dimension: u64, } impl PrimitivePageDecoder for FixedListDecoder { fn decode( 
&self, - rows_to_skip: u32, - num_rows: u32, + rows_to_skip: u64, + num_rows: u64, all_null: &mut bool, ) -> Result> { let rows_to_skip = rows_to_skip * self.dimension; diff --git a/rust/lance-encoding/src/encodings/physical/value.rs b/rust/lance-encoding/src/encodings/physical/value.rs index 046438894c..fb9aa88c73 100644 --- a/rust/lance-encoding/src/encodings/physical/value.rs +++ b/rust/lance-encoding/src/encodings/physical/value.rs @@ -82,7 +82,7 @@ impl ValuePageScheduler { impl PageScheduler for ValuePageScheduler { fn schedule_ranges( &self, - ranges: &[std::ops::Range], + ranges: &[std::ops::Range], scheduler: &Arc, top_level_row: u64, ) -> BoxFuture<'static, Result>> { @@ -91,8 +91,8 @@ impl PageScheduler for ValuePageScheduler { ranges .iter() .map(|range| { - let start = self.buffer_offset + (range.start as u64 * self.bytes_per_value); - let end = self.buffer_offset + (range.end as u64 * self.bytes_per_value); + let start = self.buffer_offset + (range.start * self.bytes_per_value); + let end = self.buffer_offset + (range.end * self.bytes_per_value); min = min.min(start); max = max.max(end); start..end @@ -122,8 +122,8 @@ impl PageScheduler for ValuePageScheduler { ranges .iter() .map(|range| { - let start = (range.start as u64 * bytes_per_value) as usize; - let end = (range.end as u64 * bytes_per_value) as usize; + let start = (range.start * bytes_per_value) as usize; + let end = (range.end * bytes_per_value) as usize; start..end }) .collect::>() @@ -206,16 +206,14 @@ impl ValuePageDecoder { impl PrimitivePageDecoder for ValuePageDecoder { fn decode( &self, - rows_to_skip: u32, - num_rows: u32, + rows_to_skip: u64, + num_rows: u64, _all_null: &mut bool, ) -> Result> { - let num_bytes = self.bytes_per_value * num_rows as u64; + let mut bytes_to_skip = rows_to_skip * self.bytes_per_value; + let mut bytes_to_take = num_rows * self.bytes_per_value; - let mut bytes_to_skip = rows_to_skip as u64 * self.bytes_per_value; - let mut bytes_to_take = num_rows as u64 * self.bytes_per_value; - - let mut dest_buffers = vec![BytesMut::with_capacity(num_bytes as usize)]; + let mut dest_buffers = vec![BytesMut::with_capacity(bytes_to_take as usize)]; let dest = &mut dest_buffers[0]; diff --git a/rust/lance-file/src/v2/reader.rs b/rust/lance-file/src/v2/reader.rs index e4383949e5..3fe182a0d1 100644 --- a/rust/lance-file/src/v2/reader.rs +++ b/rust/lance-file/src/v2/reader.rs @@ -7,7 +7,6 @@ use arrow_schema::Schema as ArrowSchema; use byteorder::{ByteOrder, LittleEndian, ReadBytesExt}; use bytes::{Bytes, BytesMut}; use futures::{stream::BoxStream, FutureExt, Stream, StreamExt}; -use lance_arrow::DataTypeExt; use lance_encoding::{ decoder::{ BatchDecodeStream, ColumnInfo, DecodeBatchScheduler, DecoderMiddlewareChain, @@ -37,10 +36,7 @@ use crate::{ format::{pb, pbfile, MAGIC, MAJOR_VERSION, MINOR_VERSION_NEXT}, }; -use lance_encoding::encoder::get_str_encoding_type; - use super::io::LanceEncodingsIo; -use arrow_schema::DataType; // For now, we don't use global buffers for anything other than schema. If we // use these later we should make them lazily loaded and then cached once loaded. @@ -459,15 +455,11 @@ impl FileReader { // Helper function for `default_projection` to determine how many columns are occupied // by a lance field. 
fn default_column_count(field: &Field) -> u32 { - if field.data_type().is_binary_like() { - 2 - } else { - 1 + field - .children - .iter() - .map(Self::default_column_count) - .sum::() - } + 1 + field + .children + .iter() + .map(Self::default_column_count) + .sum::() } // This function is one of the few spots in the reader where we rely on Lance table @@ -529,19 +521,6 @@ impl FileReader { column_infos.push(self.metadata.column_infos[*column_idx].clone()); *column_idx += 1; - if get_str_encoding_type() { - // use str array encoding - if (field.data_type().is_binary_like()) && (field.data_type() != DataType::Utf8) { - // These types are 2 columns in a lance file but a single field id in a lance schema - column_infos.push(self.metadata.column_infos[*column_idx].clone()); - *column_idx += 1; - } - } else if field.data_type().is_binary_like() { - // These types are 2 columns in a lance file but a single field id in a lance schema - column_infos.push(self.metadata.column_infos[*column_idx].clone()); - *column_idx += 1; - } - for child in &field.children { self.collect_columns(child, column_idx, column_infos)?; } @@ -950,12 +929,7 @@ impl EncodedBatchReaderExt for EncodedBatch { data: bytes, num_rows: page_table .first() - .map(|col| { - col.page_infos - .iter() - .map(|page| page.num_rows as u64) - .sum::() - }) + .map(|col| col.page_infos.iter().map(|page| page.num_rows).sum::()) .unwrap_or(0), page_table, schema: Arc::new(schema.clone()), @@ -998,12 +972,7 @@ impl EncodedBatchReaderExt for EncodedBatch { data: bytes, num_rows: page_table .first() - .map(|col| { - col.page_infos - .iter() - .map(|page| page.num_rows as u64) - .sum::() - }) + .map(|col| col.page_infos.iter().map(|page| page.num_rows).sum::()) .unwrap_or(0), page_table, schema: Arc::new(schema.clone()), @@ -1018,7 +987,6 @@ pub mod tests { use arrow_array::{ types::{Float64Type, Int32Type}, RecordBatch, - // StringArray, UInt32Array, }; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema}; use bytes::Bytes; @@ -1030,11 +998,8 @@ pub mod tests { decoder::{decode_batch, DecoderMiddlewareChain, FilterExpression}, encoder::{encode_batch, CoreFieldEncodingStrategy, EncodedBatch}, }; - // use lance_io::object_store::ObjectStore; - // use lance_io::scheduler::ScanScheduler; use lance_io::stream::RecordBatchStream; use log::debug; - // use object_store::path::Path; use crate::v2::{ reader::{EncodedBatchReaderExt, FileReader, ReaderProjection}, @@ -1053,6 +1018,8 @@ pub mod tests { .col("score", array::rand::()) .col("location", array::rand_type(&location_type)) .col("categories", array::rand_type(&categories_type)) + .col("binary", array::rand_type(&DataType::Binary)) + .col("large_bin", array::rand_type(&DataType::LargeBinary)) .into_reader_rows(RowCount::from(1000), BatchCount::from(100)); write_lance_file(reader, fs, FileWriterOptions::default()).await @@ -1450,174 +1417,4 @@ pub mod tests { let buf = file_reader.read_global_buffer(1).await.unwrap(); assert_eq!(buf, test_bytes); } - - // fn test_reading_rangefrom( - // schema: Arc, - // ) -> (lance_io::ReadBatchParams, Vec) { - // let result_batch = RecordBatch::try_new( - // schema, - // vec![ - // Arc::new(UInt32Array::from(vec![3, 4, 5, 6])), - // Arc::new(StringArray::from(vec!["abcd", "apple", "hello", "abcd"])), - // ], - // ) - // .unwrap(); - - // let result_batches = vec![result_batch]; - // let read_params = lance_io::ReadBatchParams::RangeFrom(2..); - - // (read_params, result_batches) - // } - - // fn test_reading_rangeto( - // schema: Arc, - // ) -> 
(lance_io::ReadBatchParams, Vec) { - // let read_params = lance_io::ReadBatchParams::RangeTo(..4); - // let result_batch = RecordBatch::try_new( - // schema, - // vec![ - // Arc::new(UInt32Array::from(vec![1, 2, 3, 4])), - // Arc::new(StringArray::from(vec!["abcd", "hello", "abcd", "apple"])), - // ], - // ) - // .unwrap(); - - // let result_batches = vec![result_batch]; - - // (read_params, result_batches) - // } - - // fn test_reading_random_indices( - // schema: Arc, - // ) -> (lance_io::ReadBatchParams, Vec) { - // let row_indices_vec = vec![0, 2, 4]; - // let row_indices = UInt32Array::from(row_indices_vec); - // let read_params = lance_io::ReadBatchParams::from(row_indices); - // let result_batch = RecordBatch::try_new( - // schema, - // vec![ - // Arc::new(UInt32Array::from(vec![1, 3, 5])), - // Arc::new(StringArray::from(vec!["abcd", "abcd", "hello"])), - // ], - // ) - // .unwrap(); - - // let result_batches = vec![result_batch]; - - // (read_params, result_batches) - // } - - // fn test_reading_partial_range( - // schema: Arc, - // ) -> (lance_io::ReadBatchParams, Vec) { - // let read_params = lance_io::ReadBatchParams::Range(2..4); - // let result_batch = RecordBatch::try_new( - // schema, - // vec![ - // Arc::new(UInt32Array::from(vec![3, 4])), - // Arc::new(StringArray::from(vec!["abcd", "apple"])), - // ], - // ) - // .unwrap(); - - // let result_batches = vec![result_batch]; - - // (read_params, result_batches) - // } - - // #[tokio::test] - // async fn test_string_array_encoding() { - // // set env var temporarily to test string array encoding - // let _env_guard = EnvVarGuard::new("LANCE_STR_ARRAY_ENCODING", "binary"); - - // let tmp_dir = tempfile::tempdir().unwrap(); - // let tmp_path: String = tmp_dir.path().to_str().unwrap().to_owned(); - // let tmp_path = Path::parse(tmp_path).unwrap(); - // let tmp_path = tmp_path.child("some_file.lance"); - // let obj_store = Arc::new(ObjectStore::local()); - // let writer = obj_store.create(&tmp_path).await.unwrap(); - - // let schema = Arc::new(ArrowSchema::new(vec![ - // Field::new("key", DataType::UInt32, false), - // Field::new("strings", DataType::Utf8, false), - // ])); - - // let batch1 = RecordBatch::try_new( - // schema.clone(), - // vec![ - // Arc::new(UInt32Array::from(vec![1, 2, 3])), - // Arc::new(StringArray::from(vec!["abcd", "hello", "abcd"])), - // ], - // ) - // .unwrap(); - - // let batch2 = RecordBatch::try_new( - // schema.clone(), - // vec![ - // Arc::new(UInt32Array::from(vec![4, 5, 6])), - // Arc::new(StringArray::from(vec!["apple", "hello", "abcd"])), - // ], - // ) - // .unwrap(); - - // let batches = vec![batch1, batch2]; - // let lance_schema = lance_core::datatypes::Schema::try_from(schema.as_ref()).unwrap(); - - // let mut file_writer = FileWriter::try_new( - // writer, - // tmp_path.to_string(), - // lance_schema, - // FileWriterOptions::default(), - // ) - // .unwrap(); - - // for batch in batches.clone() { - // file_writer.write_batch(&batch).await.unwrap(); - // } - - // file_writer.finish().await.unwrap(); - - // let object_store = Arc::new(ObjectStore::local()); - // let fs_scheduler = ScanScheduler::new(object_store.clone(), 8); - // let file_scheduler = fs_scheduler.open_file(&tmp_path).await.unwrap(); - - // let file_reader = - // FileReader::try_open(file_scheduler, None, DecoderMiddlewareChain::default()) - // .await - // .unwrap(); - - // for batch_size in [1, 2, 1024] { - // // Read different types of ranges from the file - // let (read_params, result_batches) = 
test_reading_rangefrom(schema.clone());
-    //         let batch_stream = file_reader
-    //             .read_stream(read_params, batch_size, 16, FilterExpression::no_filter())
-    //             .unwrap();
-    //         verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await;
-    //         let (read_params, result_batches) = test_reading_rangeto(schema.clone());
-    //         let batch_stream = file_reader
-    //             .read_stream(read_params, batch_size, 16, FilterExpression::no_filter())
-    //             .unwrap();
-    //         verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await;
-    //         let (read_params, result_batches) = test_reading_random_indices(schema.clone());
-    //         let batch_stream = file_reader
-    //             .read_stream(read_params, batch_size, 16, FilterExpression::no_filter())
-    //             .unwrap();
-    //         verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await;
-    //         let (read_params, result_batches) = test_reading_partial_range(schema.clone());
-    //         let batch_stream = file_reader
-    //             .read_stream(read_params, batch_size, 16, FilterExpression::no_filter())
-    //             .unwrap();
-    //         verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await;
-    //         let read_params = lance_io::ReadBatchParams::RangeFull;
-    //         let result_batches = batches.clone();
-    //         let batch_stream = file_reader
-    //             .read_stream(read_params, batch_size, 16, FilterExpression::no_filter())
-    //             .unwrap();
-    //         verify_expected(result_batches.as_slice(), batch_stream, batch_size, None).await;
-    //     }
-    // }
 }

From f51c5f0608bba944881d6e220ad47af739da18ff Mon Sep 17 00:00:00 2001
From: BubbleCal
Date: Tue, 25 Jun 2024 13:17:53 +0800
Subject: [PATCH 13/13] refactor: move IVF_HNSW_SQ & IVF_FLAT to new building & search path (#2469)

- IVF_HNSW_SQ for new search/build path
- IVF_FLAT e2e pass
- support to train quantizer with new index builder
- fix partition order broken after building
- clean IVF related types
- index builder method chaining for customizing
- impl merging deltas for new IVF_HNSW_SQ & IVF_FLAT

---------

Signed-off-by: BubbleCal
---
 python/python/lance/util.py | 8 +-
 python/src/lib.rs | 2 -
 python/src/utils.rs | 52 +-
 rust/lance-arrow/src/lib.rs | 24 +
 rust/lance-encoding/src/decoder.rs | 2 +-
 rust/lance-index/benches/find_partitions.rs | 8 +-
 rust/lance-index/src/lib.rs | 3 +
 rust/lance-index/src/scalar/btree.rs | 7 +
 rust/lance-index/src/scalar/flat.rs | 10 +-
 rust/lance-index/src/vector.rs | 24 +-
 rust/lance-index/src/vector/flat/index.rs | 87 ++-
 rust/lance-index/src/vector/flat/storage.rs | 2 +-
 rust/lance-index/src/vector/graph.rs | 3 +-
 rust/lance-index/src/vector/hnsw/builder.rs | 52 +-
 rust/lance-index/src/vector/hnsw/index.rs | 43 +-
 rust/lance-index/src/vector/ivf.rs | 60 +-
 rust/lance-index/src/vector/ivf/shuffler.rs | 4 +-
 rust/lance-index/src/vector/ivf/storage.rs | 191 +++---
 rust/lance-index/src/vector/ivf/transform.rs | 6 +-
 rust/lance-index/src/vector/pq/storage.rs | 5 +-
 rust/lance-index/src/vector/quantizer.rs | 100 +++-
 rust/lance-index/src/vector/sq.rs | 15 +
 rust/lance-index/src/vector/sq/builder.rs | 47 +-
 rust/lance-index/src/vector/sq/storage.rs | 32 +-
 rust/lance-index/src/vector/storage.rs | 71 +--
 rust/lance-index/src/vector/v3/shuffler.rs | 41 +-
 rust/lance-index/src/vector/v3/subindex.rs | 41 +-
 rust/lance/src/index.rs | 84 ++-
 rust/lance/src/index/vector.rs | 84 +--
 rust/lance/src/index/vector/builder.rs | 557 +++++++++++++-----
 rust/lance/src/index/vector/fixture_test.rs | 40 +-
 rust/lance/src/index/vector/ivf.rs | 579 +++++++------
rust/lance/src/index/vector/ivf/builder.rs | 21 +- rust/lance/src/index/vector/ivf/io.rs | 58 +- rust/lance/src/index/vector/ivf/v2.rs | 409 +++++++++---- rust/lance/src/index/vector/pq.rs | 51 +- rust/lance/src/index/vector/sq.rs | 75 --- rust/lance/src/io/exec/knn.rs | 20 +- rust/lance/src/session/index_extension.rs | 40 +- 39 files changed, 1772 insertions(+), 1186 deletions(-) delete mode 100644 rust/lance/src/index/vector/sq.rs diff --git a/python/python/lance/util.py b/python/python/lance/util.py index b3ce35ded6..208d8b41ff 100644 --- a/python/python/lance/util.py +++ b/python/python/lance/util.py @@ -11,7 +11,7 @@ from .dependencies import _check_for_numpy, _check_for_pandas from .dependencies import numpy as np from .dependencies import pandas as pd -from .lance import _build_sq_storage, _Hnsw, _KMeans +from .lance import _Hnsw, _KMeans if TYPE_CHECKING: ts_types = Union[datetime, pd.Timestamp, str] @@ -245,9 +245,3 @@ def to_lance_file(self, file_path): def vectors(self) -> pa.Array: return self._hnsw.vectors() - - -def build_sq_storage( - row_ids_array: Iterator[pa.Array], vectors_array: pa.Array, dim, bounds: tuple -) -> pa.RecordBatch: - return _build_sq_storage(row_ids_array, vectors_array, dim, bounds) diff --git a/python/src/lib.rs b/python/src/lib.rs index e182582fdb..b90740cf5f 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -68,7 +68,6 @@ pub(crate) mod utils; pub use crate::arrow::{bfloat16_array, BFloat16}; use crate::fragment::{cleanup_partial_writes, write_fragments}; pub use crate::tracing::{trace_to_chrome, TraceGuard}; -use crate::utils::build_sq_storage; use crate::utils::Hnsw; use crate::utils::KMeans; pub use dataset::write_dataset; @@ -142,7 +141,6 @@ fn lance(py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(cleanup_partial_writes))?; m.add_wrapped(wrap_pyfunction!(trace_to_chrome))?; m.add_wrapped(wrap_pyfunction!(manifest_needs_migration))?; - m.add_wrapped(wrap_pyfunction!(build_sq_storage))?; // Debug functions m.add_wrapped(wrap_pyfunction!(debug::format_schema))?; m.add_wrapped(wrap_pyfunction!(debug::format_manifest))?; diff --git a/python/src/utils.rs b/python/src/utils.rs index 960870f0f7..101785de24 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -14,23 +14,18 @@ use std::sync::Arc; -use arrow::compute::{concat, concat_batches}; +use arrow::compute::concat; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; -use arrow_array::{ - cast::AsArray, Array, FixedSizeListArray, Float32Array, UInt32Array, UInt64Array, -}; +use arrow_array::{cast::AsArray, Array, FixedSizeListArray, Float32Array, UInt32Array}; use arrow_data::ArrayData; use arrow_schema::DataType; use lance::Result; -use lance::{datatypes::Schema, index::vector::sq, io::ObjectStore}; +use lance::{datatypes::Schema, io::ObjectStore}; use lance_arrow::FixedSizeListArrayExt; use lance_file::writer::FileWriter; use lance_index::scalar::IndexWriter; +use lance_index::vector::hnsw::{builder::HnswBuildParams, HNSW}; use lance_index::vector::v3::subindex::IvfSubIndex; -use lance_index::vector::{ - hnsw::{builder::HnswBuildParams, HNSW}, - storage::VectorStore, -}; use lance_linalg::kmeans::compute_partitions; use lance_linalg::{ distance::DistanceType, @@ -41,7 +36,7 @@ use object_store::path::Path; use pyo3::{ exceptions::{PyIOError, PyRuntimeError, PyValueError}, prelude::*, - types::{PyIterator, PyTuple}, + types::PyIterator, }; use crate::RT; @@ -220,40 +215,3 @@ impl Hnsw { self.vectors.to_data().to_pyarrow(py) } } - -#[pyfunction(name = 
"_build_sq_storage")] -pub fn build_sq_storage( - py: Python, - row_ids_array: &PyIterator, - vectors: &PyAny, - dim: usize, - bounds: &PyTuple, -) -> PyResult { - let mut row_ids_arr: Vec> = Vec::new(); - for row_ids in row_ids_array { - let row_ids = ArrayData::from_pyarrow(row_ids?)?; - if !matches!(row_ids.data_type(), DataType::UInt64) { - return Err(PyValueError::new_err("Must be a UInt64")); - } - row_ids_arr.push(Arc::new(UInt64Array::from(row_ids))); - } - let row_ids_refs = row_ids_arr.iter().map(|a| a.as_ref()).collect::>(); - let row_ids = concat(&row_ids_refs).map_err(|e| PyIOError::new_err(e.to_string()))?; - std::mem::drop(row_ids_arr); - - let vectors = Arc::new(FixedSizeListArray::from(ArrayData::from_pyarrow(vectors)?)); - - let lower_bound = bounds.get_item(0)?.extract::()?; - let upper_bound = bounds.get_item(1)?.extract::()?; - let quantizer = - lance_index::vector::sq::ScalarQuantizer::with_bounds(8, dim, lower_bound..upper_bound); - let storage = sq::build_sq_storage(DistanceType::L2, row_ids, vectors, quantizer) - .map_err(|e| PyIOError::new_err(e.to_string()))?; - let batches = storage - .to_batches() - .map_err(|e| PyIOError::new_err(e.to_string()))? - .collect::>(); - let batch = concat_batches(&batches[0].schema(), &batches) - .map_err(|e| PyIOError::new_err(e.to_string()))?; - batch.to_pyarrow(py) -} diff --git a/rust/lance-arrow/src/lib.rs b/rust/lance-arrow/src/lib.rs index e1cf554d59..059221c268 100644 --- a/rust/lance-arrow/src/lib.rs +++ b/rust/lance-arrow/src/lib.rs @@ -5,6 +5,7 @@ //! //! To improve Arrow-RS ergonomic +use std::collections::HashMap; use std::sync::Arc; use arrow_array::{ @@ -374,6 +375,19 @@ pub trait RecordBatchExt { /// Project the schema over the [RecordBatch]. fn project_by_schema(&self, schema: &Schema) -> Result; + /// metadata of the schema. + fn metadata(&self) -> &HashMap; + + /// Add metadata to the schema. + fn add_metadata(&self, key: String, value: String) -> Result { + let mut metadata = self.metadata().clone(); + metadata.insert(key, value); + self.with_metadata(metadata) + } + + /// Replace the schema metadata with the provided one. + fn with_metadata(&self, metadata: HashMap) -> Result; + /// Take selected rows from the [RecordBatch]. fn take(&self, indices: &UInt32Array) -> Result; } @@ -460,6 +474,16 @@ impl RecordBatchExt for RecordBatch { self.try_new_from_struct_array(project(&struct_array, schema.fields())?) 
} + fn metadata(&self) -> &HashMap { + self.schema_ref().metadata() + } + + fn with_metadata(&self, metadata: HashMap) -> Result { + let mut schema = self.schema_ref().as_ref().clone(); + schema.metadata = metadata; + Self::try_new(schema.into(), self.columns().into()) + } + fn take(&self, indices: &UInt32Array) -> Result { let struct_array: StructArray = self.clone().into(); let taken = take(&struct_array, indices, None)?; diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs index 0ecf127c0a..e110680e1f 100644 --- a/rust/lance-encoding/src/decoder.rs +++ b/rust/lance-encoding/src/decoder.rs @@ -583,7 +583,7 @@ impl CoreFieldDecoderStrategy { /// Helper method to verify the page encoding of a struct header column fn check_simple_struct(column_info: &ColumnInfo, path: &VecDeque) -> Result<()> { Self::ensure_values_encoded(column_info, path)?; - if !column_info.page_infos.len() == 1 { + if column_info.page_infos.len() != 1 { return Err(Error::InvalidInput { source: format!("Due to schema we expected a struct column but we received a column with {} pages and right now we only support struct columns with 1 page", column_info.page_infos.len()).into(), location: location!() }); } let encoding = &column_info.page_infos[0].encoding; diff --git a/rust/lance-index/benches/find_partitions.rs b/rust/lance-index/benches/find_partitions.rs index 9e68602684..dd370128f0 100644 --- a/rust/lance-index/benches/find_partitions.rs +++ b/rust/lance-index/benches/find_partitions.rs @@ -11,7 +11,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; #[cfg(target_os = "linux")] use pprof::criterion::{Output, PProfProfiler}; -use lance_index::vector::ivf::Ivf; +use lance_index::vector::ivf::IvfTransformer; use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_seed; @@ -27,14 +27,14 @@ fn bench_partitions(c: &mut Criterion) { let fsl = FixedSizeListArray::try_new_from_values(centroids, DIMENSION as i32).unwrap(); for k in &[1, 10, 50] { - let ivf = Ivf::new(fsl.clone(), DistanceType::L2, vec![]); + let ivf = IvfTransformer::new(fsl.clone(), DistanceType::L2, vec![]); c.bench_function(format!("IVF{},k={},L2", num_centroids, k).as_str(), |b| { b.iter(|| { let _ = ivf.find_partitions(&query, *k); }) }); - let ivf = Ivf::new(fsl.clone(), DistanceType::Cosine, vec![]); + let ivf = IvfTransformer::new(fsl.clone(), DistanceType::Cosine, vec![]); c.bench_function( format!("IVF{},k={},Cosine", num_centroids, k).as_str(), |b| { @@ -45,7 +45,7 @@ fn bench_partitions(c: &mut Criterion) { ); } - let ivf = Ivf::new(fsl.clone(), DistanceType::L2, vec![]); + let ivf = IvfTransformer::new(fsl.clone(), DistanceType::L2, vec![]); let batch = generate_random_array_with_seed::(DIMENSION * 4096, SEED); let fsl = FixedSizeListArray::try_new_from_values(batch, DIMENSION as i32).unwrap(); c.bench_function( diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 95492daf5f..18521f9222 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -52,6 +52,9 @@ pub trait Index: Send + Sync + DeepSizeOf { /// Cast to [Index] fn as_index(self: Arc) -> Arc; + /// Cast to [vector::VectorIndex] + fn as_vector_index(self: Arc) -> Result>; + /// Retrieve index statistics as a JSON Value fn statistics(&self) -> Result; diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 39916cab2c..752d56cc6d 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ 
-789,6 +789,13 @@ impl Index for BTreeIndex { self } + fn as_vector_index(self: Arc) -> Result> { + Err(Error::NotSupported { + source: "BTreeIndex is not vector index".into(), + location: location!(), + }) + } + fn index_type(&self) -> IndexType { IndexType::Scalar } diff --git a/rust/lance-index/src/scalar/flat.rs b/rust/lance-index/src/scalar/flat.rs index 2d7ac03bce..842773929a 100644 --- a/rust/lance-index/src/scalar/flat.rs +++ b/rust/lance-index/src/scalar/flat.rs @@ -14,8 +14,9 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_physical_expr::expressions::{in_list, lit, Column}; use deepsize::DeepSizeOf; use lance_core::utils::address::RowAddress; -use lance_core::Result; +use lance_core::{Error, Result}; use roaring::RoaringBitmap; +use snafu::{location, Location}; use crate::{Index, IndexType}; @@ -157,6 +158,13 @@ impl Index for FlatIndex { self } + fn as_vector_index(self: Arc) -> Result> { + Err(Error::NotSupported { + source: "FlatIndex is not vector index".into(), + location: location!(), + }) + } + fn index_type(&self) -> IndexType { IndexType::Scalar } diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index ebaee29f1f..63ad4955f3 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -6,13 +6,16 @@ use std::{collections::HashMap, sync::Arc}; -use arrow_array::{ArrayRef, RecordBatch}; +use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; use arrow_schema::Field; use async_trait::async_trait; +use ivf::storage::IvfModel; use lance_core::{Result, ROW_ID_FIELD}; use lance_io::traits::Reader; use lance_linalg::distance::DistanceType; use lazy_static::lazy_static; +use quantizer::{QuantizationType, Quantizer}; +use v3::subindex::SubIndexType; pub mod bq; pub mod flat; @@ -102,6 +105,7 @@ impl From for pb::VectorMetricType { } /// Vector Index for (Approximate) Nearest Neighbor (ANN) Search. +/// It's always the IVF index, any other index types without partitioning will be treated as IVF with one partition. #[async_trait] #[allow(clippy::redundant_pub_crate)] pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { @@ -125,6 +129,15 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// - Only supports `f32` now. Will add f64/f16 later. async fn search(&self, query: &Query, pre_filter: Arc) -> Result; + fn find_partitions(&self, query: &Query) -> Result; + + async fn search_in_partition( + &self, + partition_id: usize, + query: &Query, + pre_filter: Arc, + ) -> Result; + /// If the index is loadable by IVF, so it can be a sub-index that /// is loaded on demand by IVF. fn is_loadable(&self) -> bool; @@ -136,6 +149,9 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// explaining why not fn check_can_remap(&self) -> Result<()>; + // async fn append(&self, batches: Vec) -> Result<()>; + // async fn merge(&self, indices: Vec>) -> Result<()>; + /// Load the index from the reader on-demand. async fn load( &self, @@ -170,4 +186,10 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// The metric type of this vector index. fn metric_type(&self) -> DistanceType; + + fn ivf_model(&self) -> IvfModel; + fn quantizer(&self) -> Quantizer; + + /// the index type of this vector index. 
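+    /// (for example, an IVF_FLAT index pairs a flat sub-index with flat
+    /// quantization, while IVF_HNSW_SQ pairs an HNSW sub-index with scalar
+    /// quantization)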
+ fn sub_index_type(&self) -> (SubIndexType, QuantizationType); } diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index 0a7c7d97f8..8ffbfe97c0 100644 --- a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -6,14 +6,16 @@ use std::{collections::HashSet, sync::Arc}; +use arrow::array::AsArray; use arrow_array::{Array, ArrayRef, Float32Array, RecordBatch, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use deepsize::DeepSizeOf; use itertools::Itertools; -use lance_core::{Result, ROW_ID_FIELD}; +use lance_core::{Error, Result, ROW_ID_FIELD}; use lance_file::reader::FileReader; use lance_linalg::distance::DistanceType; use serde::{Deserialize, Serialize}; +use snafu::{location, Location}; use crate::{ prefilter::PreFilter, @@ -61,6 +63,10 @@ impl IvfSubIndex for FlatIndex { "FLAT" } + fn metadata_key() -> &'static str { + "lance:flat" + } + fn schema() -> arrow_schema::SchemaRef { Schema::new(vec![Field::new("__flat_marker", DataType::UInt64, false)]).into() } @@ -74,25 +80,44 @@ impl IvfSubIndex for FlatIndex { prefilter: Arc, ) -> Result { let dist_calc = storage.dist_calculator(query); - let filtered_row_ids = prefilter - .filter_row_ids(Box::new(storage.row_ids())) - .into_iter() - .collect::>(); - let (row_ids, dists): (Vec, Vec) = (0..storage.len()) - .filter(|&id| !filtered_row_ids.contains(&storage.row_id(id as u32))) - .map(|id| OrderedNode { - id: id as u32, - dist: OrderedFloat(dist_calc.distance(id as u32)), - }) - .sorted_unstable() - .take(k) - .map( - |OrderedNode { - id, - dist: OrderedFloat(dist), - }| (storage.row_id(id), dist), - ) - .unzip(); + + let (row_ids, dists): (Vec, Vec) = match prefilter.is_empty() { + true => (0..storage.len()) + .map(|id| OrderedNode { + id: id as u32, + dist: OrderedFloat(dist_calc.distance(id as u32)), + }) + .sorted_unstable() + .take(k) + .map( + |OrderedNode { + id, + dist: OrderedFloat(dist), + }| (storage.row_id(id), dist), + ) + .unzip(), + false => { + let filtered_row_ids = prefilter + .filter_row_ids(Box::new(storage.row_ids())) + .into_iter() + .collect::>(); + (0..storage.len()) + .filter(|&id| !filtered_row_ids.contains(&storage.row_id(id as u32))) + .map(|id| OrderedNode { + id: id as u32, + dist: OrderedFloat(dist_calc.distance(id as u32)), + }) + .sorted_unstable() + .take(k) + .map( + |OrderedNode { + id, + dist: OrderedFloat(dist), + }| (storage.row_id(id), dist), + ) + .unzip() + } + }; let (row_ids, dists) = (UInt64Array::from(row_ids), Float32Array::from(dists)); @@ -143,9 +168,15 @@ impl FlatQuantizer { } impl Quantization for FlatQuantizer { + type BuildParams = (); type Metadata = FlatMetadata; type Storage = FlatStorage; + fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result { + let dim = data.as_fixed_size_list().value_length(); + Ok(Self::new(dim as usize, distance_type)) + } + fn code_dim(&self) -> usize { self.dim } @@ -173,7 +204,7 @@ impl Quantization for FlatQuantizer { "flat" } - fn quantization_type(&self) -> QuantizationType { + fn quantization_type() -> QuantizationType { QuantizationType::Flat } @@ -187,3 +218,17 @@ impl From for Quantizer { Self::Flat(value) } } + +impl TryFrom for FlatQuantizer { + type Error = Error; + + fn try_from(value: Quantizer) -> Result { + match value { + Quantizer::Flat(quantizer) => Ok(quantizer), + _ => Err(Error::invalid_input( + "quantizer is not FlatQuantizer", + location!(), + )), + } + } +} diff --git 
a/rust/lance-index/src/vector/flat/storage.rs b/rust/lance-index/src/vector/flat/storage.rs index 62769dc151..7173697233 100644 --- a/rust/lance-index/src/vector/flat/storage.rs +++ b/rust/lance-index/src/vector/flat/storage.rs @@ -26,7 +26,7 @@ use super::index::FlatMetadata; pub const FLAT_COLUMN: &str = "flat"; /// All data are stored in memory -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct FlatStorage { batch: RecordBatch, distance_type: DistanceType, diff --git a/rust/lance-index/src/vector/graph.rs b/rust/lance-index/src/vector/graph.rs index fde0d5229e..3cc1bd425d 100644 --- a/rust/lance-index/src/vector/graph.rs +++ b/rust/lance-index/src/vector/graph.rs @@ -232,9 +232,8 @@ pub fn beam_search( dist_calc: &impl DistCalculator, bitset: Option<&roaring::bitmap::RoaringBitmap>, prefetch_distance: Option, - visited_generator: &mut VisitedGenerator, + visited: &mut Visited, ) -> Vec { - let mut visited = visited_generator.generate(graph.len()); //let mut visited: HashSet<_> = HashSet::with_capacity(k); let mut candidates = BinaryHeap::with_capacity(k); visited.insert(ep.id); diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index bd664abf4b..e9c487b4d0 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -54,9 +54,6 @@ pub struct HnswBuildParams { /// size of the dynamic list for the candidates pub ef_construction: usize, - /// the max number of threads to use for building the graph - pub parallel_limit: Option, - /// number of vectors ahead to prefetch while building the graph pub prefetch_distance: Option, } @@ -67,7 +64,6 @@ impl Default for HnswBuildParams { max_level: 7, m: 20, ef_construction: 150, - parallel_limit: None, prefetch_distance: Some(2), } } @@ -97,12 +93,6 @@ impl HnswBuildParams { self } - /// The max number of threads to use for building the graph. - pub fn parallel_limit(mut self, limit: usize) -> Self { - self.parallel_limit = Some(limit); - self - } - /// Build the HNSW index from the given data. 
/// /// # Parameters @@ -136,6 +126,18 @@ impl Debug for HNSW { } impl HNSW { + pub fn empty() -> Self { + Self { + inner: Arc::new(HnswBuilder { + params: HnswBuildParams::default(), + nodes: Arc::new(Vec::new()), + level_count: Vec::new(), + entry_point: 0, + visited_generator_queue: Arc::new(ArrayQueue::new(1)), + }), + } + } + pub fn len(&self) -> usize { self.inner.nodes.len() } @@ -176,6 +178,7 @@ impl HNSW { } let bottom_level = HnswBottomView::new(nodes); + let mut visited = visited_generator.generate(storage.len()); Ok(beam_search( &bottom_level, &ep, @@ -183,7 +186,7 @@ impl HNSW { &dist_calc, bitset.as_ref(), prefetch_distance, - visited_generator, + &mut visited, ) .into_iter() .take(k) @@ -415,6 +418,7 @@ impl HnswBuilder { visited_generator: &mut VisitedGenerator, ) -> Vec { let cur_level = HnswLevelView::new(level, nodes); + let mut visited = visited_generator.generate(nodes.len()); beam_search( &cur_level, ep, @@ -422,7 +426,7 @@ impl HnswBuilder { dist_calc, None, self.params.prefetch_distance, - visited_generator, + &mut visited, ) } @@ -511,6 +515,10 @@ impl IvfSubIndex for HNSW { where Self: Sized, { + if data.num_rows() == 0 { + return Ok(Self::empty()); + } + let hnsw_metadata = data.schema_ref() .metadata() @@ -519,7 +527,14 @@ impl IvfSubIndex for HNSW { message: format!("{} not found", HNSW_METADATA_KEY), location: location!(), })?; - let hnsw_metadata: HnswMetadata = serde_json::from_str(hnsw_metadata)?; + let hnsw_metadata: HnswMetadata = + serde_json::from_str(hnsw_metadata).map_err(|e| Error::Index { + message: format!( + "Failed to decode HNSW metadata: {}, json: {}", + e, hnsw_metadata + ), + location: location!(), + })?; let levels: Vec<_> = hnsw_metadata .level_offsets @@ -583,6 +598,10 @@ impl IvfSubIndex for HNSW { HNSW_TYPE } + fn metadata_key() -> &'static str { + "lance:hnsw" + } + /// Return the schema of the sub index fn schema() -> arrow_schema::SchemaRef { arrow_schema::Schema::new(vec![ @@ -661,13 +680,6 @@ impl IvfSubIndex for HNSW { ); let len = storage.len(); - let parallel_limit = hnsw - .inner - .params - .parallel_limit - .unwrap_or_else(num_cpus::get) - .max(1); - log::info!("Building HNSW graph with parallel_limit={}", parallel_limit); hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed); (1..len).into_par_iter().for_each(|node| { let mut visited_generator = VisitedGenerator::new(len); diff --git a/rust/lance-index/src/vector/hnsw/index.rs b/rust/lance-index/src/vector/hnsw/index.rs index 01a4b456f7..783372b9f1 100644 --- a/rust/lance-index/src/vector/hnsw/index.rs +++ b/rust/lance-index/src/vector/hnsw/index.rs @@ -8,7 +8,7 @@ use std::{ sync::Arc, }; -use arrow_array::RecordBatch; +use arrow_array::{RecordBatch, UInt32Array}; use async_trait::async_trait; use deepsize::DeepSizeOf; use lance_core::{datatypes::Schema, Error, Result}; @@ -22,7 +22,9 @@ use snafu::{location, Location}; use tracing::instrument; use crate::prefilter::PreFilter; -use crate::vector::v3::subindex::IvfSubIndex; +use crate::vector::ivf::storage::IvfModel; +use crate::vector::quantizer::QuantizationType; +use crate::vector::v3::subindex::{IvfSubIndex, SubIndexType}; use crate::{ vector::{ graph::NEIGHBORS_FIELD, @@ -35,8 +37,6 @@ use crate::{ Index, IndexType, }; -use super::builder::HNSW_METADATA_KEY; - #[derive(Clone, DeepSizeOf)] pub struct HNSWIndexOptions { pub use_residual: bool, @@ -118,6 +118,11 @@ impl Index for HNSWIndex { self } + /// Cast to [VectorIndex] + fn as_vector_index(self: Arc) -> Result> { + Ok(self) + } + /// Retrieve index statistics as 
a JSON Value fn statistics(&self) -> Result { Ok(json!({ @@ -166,6 +171,19 @@ impl VectorIndex for HNSWIndex { ) } + fn find_partitions(&self, _: &Query) -> Result { + unimplemented!("only for IVF") + } + + async fn search_in_partition( + &self, + _: usize, + _: &Query, + _: Arc, + ) -> Result { + unimplemented!("only for IVF") + } + fn is_loadable(&self) -> bool { true } @@ -230,7 +248,7 @@ impl VectorIndex for HNSWIndex { .await?; let mut schema = batch.schema_ref().as_ref().clone(); schema.metadata.insert( - HNSW_METADATA_KEY.to_string(), + HNSW::metadata_key().to_owned(), serde_json::to_string(&metadata)?, ); let batch = batch.with_schema(schema.into())?; @@ -256,6 +274,21 @@ impl VectorIndex for HNSWIndex { }) } + fn ivf_model(&self) -> IvfModel { + unimplemented!("only for IVF") + } + + fn quantizer(&self) -> Quantizer { + self.partition_storage.quantizer().clone() + } + + fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + ( + SubIndexType::Hnsw, + self.partition_storage.quantizer().quantization_type(), + ) + } + fn metric_type(&self) -> DistanceType { self.partition_storage.distance_type() } diff --git a/rust/lance-index/src/vector/ivf.rs b/rust/lance-index/src/vector/ivf.rs index 4e5b8425d8..87fcdf164d 100644 --- a/rust/lance-index/src/vector/ivf.rs +++ b/rust/lance-index/src/vector/ivf.rs @@ -15,14 +15,13 @@ use lance_linalg::{ kmeans::{compute_partitions_arrow_array, kmeans_find_partitions_arrow_array}, }; -use crate::vector::ivf::transform::IvfTransformer; +use crate::vector::ivf::transform::PartitionTransformer; use crate::vector::{ pq::{transform::PQTransformer, ProductQuantizer}, residual::ResidualTransform, transform::Transformer, }; -use super::transform::DropColumn; use super::{quantizer::Quantizer, residual::compute_residual}; use super::{PART_ID_COLUMN, PQ_CODE_COLUMN, RESIDUAL_COLUMN}; @@ -40,38 +39,48 @@ mod transform; /// - *metric_type*: metric type to compute pair-wise vector distance. /// - *transforms*: a list of transforms to apply to the vector column. /// - *range*: only covers a range of partitions. Default is None -pub fn new_ivf( +pub fn new_ivf_transformer( centroids: FixedSizeListArray, metric_type: DistanceType, transforms: Vec>, -) -> Ivf { - Ivf::new(centroids, metric_type, transforms) +) -> IvfTransformer { + IvfTransformer::new(centroids, metric_type, transforms) } -pub fn new_ivf_with_quantizer( +pub fn new_ivf_transformer_with_quantizer( centroids: FixedSizeListArray, metric_type: MetricType, vector_column: &str, quantizer: Quantizer, range: Option>, -) -> Result { +) -> Result { match quantizer { - Quantizer::Flat(_) => Ok(Ivf::new_flat(centroids, metric_type, vector_column, range)), - Quantizer::Product(pq) => Ok(Ivf::with_pq( + Quantizer::Flat(_) => Ok(IvfTransformer::new_flat( + centroids, + metric_type, + vector_column, + range, + )), + Quantizer::Product(pq) => Ok(IvfTransformer::with_pq( centroids, metric_type, vector_column, pq, range, )), - Quantizer::Scalar(_) => Ok(Ivf::with_sq(centroids, metric_type, vector_column, range)), + Quantizer::Scalar(_) => Ok(IvfTransformer::with_sq( + centroids, + metric_type, + vector_column, + range, + )), } } /// IVF - IVF file partition /// #[derive(Debug)] -pub struct Ivf { +pub struct IvfTransformer { /// Centroids of a cluster algorithm, to run IVF. /// /// It is a 2-D `(num_partitions * dimension)` of floating array. @@ -84,7 +93,7 @@ pub struct Ivf { distance_type: DistanceType, } -impl Ivf { +impl IvfTransformer { /// Create a new Ivf model. 
pub fn new( centroids: FixedSizeListArray, @@ -115,7 +124,11 @@ impl Ivf { distance_type }; - let ivf_transform = Arc::new(IvfTransformer::new(centroids.clone(), dt, vector_column)); + let ivf_transform = Arc::new(PartitionTransformer::new( + centroids.clone(), + dt, + vector_column, + )); transforms.push(ivf_transform.clone()); if let Some(range) = range { @@ -151,8 +164,12 @@ impl Ivf { distance_type }; - let ivf_transform = Arc::new(IvfTransformer::new(centroids.clone(), mt, vector_column)); - transforms.push(ivf_transform.clone()); + let partition_transform = Arc::new(PartitionTransformer::new( + centroids.clone(), + mt, + vector_column, + )); + transforms.push(partition_transform.clone()); if let Some(range) = range { transforms.push(Arc::new(transform::PartitionFilter::new( @@ -203,8 +220,12 @@ impl Ivf { metric_type }; - let ivf_transform = Arc::new(IvfTransformer::new(centroids.clone(), mt, vector_column)); - transforms.push(ivf_transform.clone()); + let partition_transformer = Arc::new(PartitionTransformer::new( + centroids.clone(), + mt, + vector_column, + )); + transforms.push(partition_transformer.clone()); if let Some(range) = range { transforms.push(Arc::new(transform::PartitionFilter::new( @@ -213,9 +234,6 @@ impl Ivf { ))); } - // For SQ we will transofrm the vector to SQ code while building the index, - // so simply drop the vector column now. - transforms.push(Arc::new(DropColumn::new(vector_column))); Self { centroids, distance_type: metric_type, @@ -243,7 +261,7 @@ impl Ivf { } } -impl Transformer for Ivf { +impl Transformer for IvfTransformer { fn transform(&self, batch: &RecordBatch) -> Result { let mut batch = batch.clone(); for transform in self.transforms.as_slice() { diff --git a/rust/lance-index/src/vector/ivf/shuffler.rs b/rust/lance-index/src/vector/ivf/shuffler.rs index dd02ff9566..e76387b359 100644 --- a/rust/lance-index/src/vector/ivf/shuffler.rs +++ b/rust/lance-index/src/vector/ivf/shuffler.rs @@ -33,7 +33,7 @@ use object_store::path::Path; use snafu::{location, Location}; use tempfile::TempDir; -use crate::vector::ivf::Ivf; +use crate::vector::ivf::IvfTransformer; use crate::vector::transform::{KeepFiniteVectors, Transformer}; use crate::vector::PART_ID_COLUMN; @@ -70,7 +70,7 @@ fn get_temp_dir() -> Result { pub async fn shuffle_dataset( data: impl RecordBatchStream + Unpin + 'static, column: &str, - ivf: Arc, + ivf: Arc, precomputed_partitions: Option>, num_partitions: u32, shuffle_partition_batches: usize, diff --git a/rust/lance-index/src/vector/ivf/storage.rs b/rust/lance-index/src/vector/ivf/storage.rs index dc94351245..9d5a8fafad 100644 --- a/rust/lance-index/src/vector/ivf/storage.rs +++ b/rust/lance-index/src/vector/ivf/storage.rs @@ -2,13 +2,14 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::ops::Range; -use std::sync::Arc; -use arrow_array::{Array, FixedSizeListArray}; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt32Array}; use deepsize::DeepSizeOf; +use itertools::Itertools; use lance_core::{Error, Result}; use lance_file::{reader::FileReader, writer::FileWriter}; use lance_io::{traits::WriteExt, utils::read_message}; +use lance_linalg::distance::DistanceType; use lance_table::io::manifest::ManifestDescribing; use log::debug; use serde::{Deserialize, Serialize}; @@ -19,53 +20,110 @@ use crate::pb::Ivf as PbIvf; pub const IVF_METADATA_KEY: &str = "lance:ivf"; pub const IVF_PARTITION_KEY: &str = "lance:ivf:partition"; +/// Ivf Model #[derive(Debug, Clone, PartialEq)] -pub struct IvfData { - /// Centroids of the 
IVF indices. Can be empty. - centroids: Option>, +pub struct IvfModel { + /// Centroids of each partition. + /// + /// It is a 2-D `(num_partitions * dimension)` array of vectors. + pub centroids: Option, - /// Length of each partition. - lengths: Vec, + /// Offset of each partition in the file. + pub offsets: Vec, - /// pre-computed row offset for each partition, do not persist. - partition_row_offsets: Vec, + /// Number of vectors in each partition. + pub lengths: Vec, } -impl DeepSizeOf for IvfData { +impl DeepSizeOf for IvfModel { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.centroids .as_ref() .map(|centroids| centroids.get_array_memory_size()) - .unwrap_or(0) + .unwrap_or_default() + self.lengths.deep_size_of_children(context) - + self.partition_row_offsets.deep_size_of_children(context) + + self.offsets.deep_size_of_children(context) } } -/// The IVF metadata stored in the Lance Schema -#[derive(Serialize, Deserialize, Debug)] -struct IvfMetadata { - // The file position to store the protobuf binary of IVF metadata. - pb_position: usize, -} - -impl IvfData { +impl IvfModel { pub fn empty() -> Self { Self { centroids: None, + offsets: vec![], lengths: vec![], - partition_row_offsets: vec![0], } } - pub fn with_centroids(centroids: Arc) -> Self { + pub fn new(centroids: FixedSizeListArray) -> Self { Self { centroids: Some(centroids), + offsets: vec![], lengths: vec![], - partition_row_offsets: vec![0], } } + pub fn centroid(&self, partition: usize) -> Option { + self.centroids.as_ref().map(|c| c.value(partition)) + } + + /// Ivf model dimension. + pub fn dimension(&self) -> usize { + self.centroids + .as_ref() + .map(|c| c.value_length() as usize) + .unwrap_or(0) + } + + /// Number of IVF partitions. + pub fn num_partitions(&self) -> usize { + self.centroids + .as_ref() + .map(|c| c.len()) + .unwrap_or_else(|| self.offsets.len()) + } + + pub fn partition_size(&self, part: usize) -> usize { + self.lengths[part] as usize + } + + /// Use the query vector to find the `nprobes` closest partitions. + pub fn find_partitions( + &self, + query: &dyn Array, + nprobes: usize, + distance_type: DistanceType, + ) -> Result { + let internal = crate::vector::ivf::new_ivf_transformer( + self.centroids.clone().unwrap(), + distance_type, + vec![], + ); + internal.find_partitions(query, nprobes) + } + + /// Add the offset and length of one partition. + pub fn add_partition(&mut self, len: u32) { + self.offsets.push( + self.offsets.last().cloned().unwrap_or_default() + + self.lengths.last().cloned().unwrap_or_default() as usize, + ); + self.lengths.push(len); + } + + /// Add the offset and length of one partition with the given offset. + /// This is used for the old IVF_PQ index format.
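+ /// Unlike `add_partition`, which derives the offset from the running total of the previous partition lengths, this takes the given file offset verbatim.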
+ pub fn add_partition_with_offset(&mut self, offset: usize, len: u32) { + self.offsets.push(offset); + self.lengths.push(len); + } + + pub fn row_range(&self, partition: usize) -> Range { + let start = self.offsets[partition]; + let end = start + self.lengths[partition] as usize; + start..end + } + pub async fn load(reader: &FileReader) -> Result { let schema = reader.schema(); let meta_str = schema.metadata.get(IVF_METADATA_KEY).ok_or(Error::Index { @@ -94,37 +152,32 @@ impl IvfData { writer.add_metadata(IVF_METADATA_KEY, &serde_json::to_string(&ivf_metadata)?); Ok(()) } +} - pub fn add_partition(&mut self, num_rows: u32) { - self.lengths.push(num_rows); - let last_offset = self.partition_row_offsets.last().copied().unwrap_or(0); - self.partition_row_offsets - .push(last_offset + num_rows as usize); - } +/// Convert IvfModel to protobuf. +impl TryFrom<&IvfModel> for PbIvf { + type Error = Error; - pub fn has_centroids(&self) -> bool { - self.centroids.is_some() - } + fn try_from(ivf: &IvfModel) -> Result { + let lengths = ivf.lengths.clone(); - pub fn num_partitions(&self) -> usize { - self.lengths.len() - } - - /// Range of the rows for one partition. - pub fn row_range(&self, partition: usize) -> Range { - let start = self.partition_row_offsets[partition]; - let end = self.partition_row_offsets[partition + 1]; - start..end + Ok(Self { + centroids: vec![], // Deprecated + lengths, + offsets: ivf.offsets.iter().map(|x| *x as u64).collect(), + centroids_tensor: ivf.centroids.as_ref().map(|c| c.try_into()).transpose()?, + }) } } -impl TryFrom for IvfData { +/// Convert protobuf to IvfModel. +impl TryFrom for IvfModel { type Error = Error; fn try_from(proto: PbIvf) -> Result { let centroids = if let Some(tensor) = proto.centroids_tensor.as_ref() { debug!("Ivf: loading IVF centroids from index format v2"); - Some(Arc::new(FixedSizeListArray::try_from(tensor)?)) + Some(FixedSizeListArray::try_from(tensor)?) } else { None }; @@ -132,42 +185,38 @@ impl TryFrom for IvfData { // v1 index format. It will be deprecated soon. // // This new offset uses the row offset in the lance file. - let offsets = [0] - .iter() - .chain(proto.lengths.iter()) - .scan(0_usize, |state, &x| { - *state += x as usize; - Some(*state) - }); + let offsets = match proto.offsets.len() { + 0 => proto + .lengths + .iter() + .scan(0_usize, |state, &x| { + let old = *state; + *state += x as usize; + Some(old) + }) + .collect_vec(), + _ => proto.offsets.iter().map(|x| *x as usize).collect(), + }; + assert_eq!(offsets.len(), proto.lengths.len()); Ok(Self { centroids, + offsets, lengths: proto.lengths.clone(), - partition_row_offsets: offsets.collect(), }) } } -impl TryFrom<&IvfData> for PbIvf { - type Error = Error; - - fn try_from(meta: &IvfData) -> Result { - let lengths = meta.lengths.clone(); - - Ok(Self { - centroids: vec![], // Deprecated - lengths, - offsets: vec![], // Deprecated - centroids_tensor: meta - .centroids - .as_ref() - .map(|c| c.as_ref().try_into()) - .transpose()? - }) - } +/// The IVF metadata stored in the Lance Schema +#[derive(Serialize, Deserialize, Debug)] +struct IvfMetadata { + // The file position to store the protobuf binary of IVF metadata.
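+ // Serialized to JSON into the file schema metadata under IVF_METADATA_KEY; load() reads it back to locate and decode the IVF protobuf.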
+ pb_position: usize, } #[cfg(test)] mod tests { + use std::sync::Arc; + use arrow_array::{Float32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use lance_core::datatypes::Schema; @@ -179,7 +228,7 @@ mod tests { #[test] fn test_ivf_find_rows() { - let mut ivf = IvfData::empty(); + let mut ivf = IvfModel::empty(); ivf.add_partition(20); ivf.add_partition(50); @@ -189,7 +238,7 @@ mod tests { #[tokio::test] async fn test_write_and_load() { - let mut ivf = IvfData::empty(); + let mut ivf = IvfModel::empty(); ivf.add_partition(20); ivf.add_partition(50); @@ -219,7 +268,7 @@ mod tests { .unwrap(); assert!(reader.schema().metadata.contains_key(IVF_METADATA_KEY)); - let ivf2 = IvfData::load(&reader).await.unwrap(); + let ivf2 = IvfModel::load(&reader).await.unwrap(); assert_eq!(ivf, ivf2); assert_eq!(ivf2.num_partitions(), 2); } diff --git a/rust/lance-index/src/vector/ivf/transform.rs b/rust/lance-index/src/vector/ivf/transform.rs index 81bd5e376f..d05bb8d090 100644 --- a/rust/lance-index/src/vector/ivf/transform.rs +++ b/rust/lance-index/src/vector/ivf/transform.rs @@ -30,14 +30,14 @@ use super::PART_ID_COLUMN; /// this transform is a Noop. /// #[derive(Debug)] -pub struct IvfTransformer { +pub struct PartitionTransformer { centroids: FixedSizeListArray, distance_type: DistanceType, input_column: String, output_column: String, } -impl IvfTransformer { +impl PartitionTransformer { pub fn new( centroids: FixedSizeListArray, distance_type: DistanceType, @@ -60,7 +60,7 @@ impl IvfTransformer { .into() } } -impl Transformer for IvfTransformer { +impl Transformer for PartitionTransformer { fn transform(&self, batch: &RecordBatch) -> Result { if batch.column_by_name(&self.output_column).is_some() { // If the partition ID column is already present, we don't need to compute it again. 
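What `PartitionTransformer` computes is, at its core, a nearest-centroid assignment over the IVF centroids. A minimal self-contained sketch of that step under squared-L2 distance (plain slices and a hypothetical `assign_partitions` helper, not the lance API):

```rust
/// Assign each vector to its nearest centroid under squared-L2 distance.
/// `vectors` and `centroids` are row-major, `dim` floats per row.
fn assign_partitions(vectors: &[f32], centroids: &[f32], dim: usize) -> Vec<u32> {
    vectors
        .chunks_exact(dim)
        .map(|v| {
            centroids
                .chunks_exact(dim)
                .enumerate()
                .map(|(i, c)| {
                    // Squared L2 is enough for the argmin; sqrt is monotonic.
                    let d: f32 = v.iter().zip(c).map(|(a, b)| (a - b) * (a - b)).sum();
                    (i as u32, d)
                })
                .min_by(|a, b| a.1.total_cmp(&b.1))
                .map(|(i, _)| i)
                .expect("at least one centroid")
        })
        .collect()
}
```

The transformer itself performs the same argmin over Arrow `FixedSizeListArray` data and appends the result to the batch as the partition-id output column (`PART_ID_COLUMN`).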
diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 87b0fcd19e..4171fc40da 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -32,6 +32,7 @@ use serde::{Deserialize, Serialize}; use snafu::{location, Location}; use super::{distance::build_distance_table_l2, num_centroids, ProductQuantizerImpl}; +use crate::vector::storage::STORAGE_METADATA_KEY; use crate::{ pb, vector::{ @@ -398,7 +399,7 @@ impl VectorStore for ProductQuantizationStorage { let metadata_json = batch .schema_ref() .metadata() - .get("metadata") + .get(STORAGE_METADATA_KEY) .ok_or(Error::Index { message: "Metadata not found in schema".to_string(), location: location!(), @@ -456,7 +457,7 @@ impl VectorStore for ProductQuantizationStorage { }; let metadata_json = serde_json::to_string(&metadata)?; - let metadata = HashMap::from_iter(vec![("metadata".to_string(), metadata_json)]); + let metadata = HashMap::from_iter(vec![(STORAGE_METADATA_KEY.to_string(), metadata_json)]); let schema = self .batch diff --git a/rust/lance-index/src/vector/quantizer.rs b/rust/lance-index/src/vector/quantizer.rs index 550ec463bb..83ea913d54 100644 --- a/rust/lance-index/src/vector/quantizer.rs +++ b/rust/lance-index/src/vector/quantizer.rs @@ -2,10 +2,13 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use core::fmt; +use std::fmt::Debug; use std::sync::Arc; -use arrow::datatypes::Float32Type; +use arrow::array::AsArray; +use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array}; +use arrow_schema::DataType; use async_trait::async_trait; use deepsize::DeepSizeOf; use lance_arrow::ArrowFloatType; @@ -22,9 +25,10 @@ use crate::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; use super::flat::index::FlatQuantizer; use super::pq::storage::PQ_METADTA_KEY; use super::pq::ProductQuantizer; +use super::sq::builder::SQBuildParams; use super::sq::storage::SQ_METADATA_KEY; use super::{ - ivf::storage::IvfData, + ivf::storage::IvfModel, pq::{ storage::{ProductQuantizationMetadata, ProductQuantizationStorage}, ProductQuantizerImpl, @@ -37,15 +41,21 @@ use super::{ }; use super::{PQ_CODE_COLUMN, SQ_CODE_COLUMN}; -pub trait Quantization: Send + Sync + DeepSizeOf + Into { +pub trait Quantization: Send + Sync + Debug + DeepSizeOf + Into { + type BuildParams: QuantizerBuildParams; type Metadata: QuantizerMetadata + Send + Sync; - type Storage: QuantizerStorage + VectorStore; + type Storage: QuantizerStorage + VectorStore + Debug; + fn build( + data: &dyn Array, + distance_type: DistanceType, + params: &Self::BuildParams, + ) -> Result; fn code_dim(&self) -> usize; fn column(&self) -> &'static str; fn quantize(&self, vectors: &dyn Array) -> Result; fn metadata_key() -> &'static str; - fn quantization_type(&self) -> QuantizationType; + fn quantization_type() -> QuantizationType; fn metadata(&self, _: Option) -> Result; fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result; } @@ -66,6 +76,16 @@ impl std::fmt::Display for QuantizationType { } } +pub trait QuantizerBuildParams { + fn sample_size(&self) -> usize; +} + +impl QuantizerBuildParams for () { + fn sample_size(&self) -> usize { + 0 + } +} + /// Quantization Method. /// ///
@@ -105,9 +125,9 @@ impl Quantizer { pub fn quantization_type(&self) -> QuantizationType { match self { - Self::Flat(fq) => fq.quantization_type(), - Self::Product(pq) => pq.quantization_type(), - Self::Scalar(sq) => sq.quantization_type(), + Self::Flat(_) => QuantizationType::Flat, + Self::Product(_) => QuantizationType::Product, + Self::Scalar(_) => QuantizationType::Scalar, } } @@ -168,9 +188,42 @@ pub trait QuantizerStorage: Clone + Sized + DeepSizeOf + VectorStore { } impl Quantization for ScalarQuantizer { + type BuildParams = SQBuildParams; type Metadata = ScalarQuantizationMetadata; type Storage = ScalarQuantizationStorage; + fn build(data: &dyn Array, _: DistanceType, params: &Self::BuildParams) -> Result { + let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index { + message: format!( + "SQ builder: input is not a FixedSizeList: {}", + data.data_type() + ), + location: location!(), + })?; + + let mut quantizer = Self::new(params.num_bits, fsl.value_length() as usize); + + match fsl.value_type() { + DataType::Float16 => { + quantizer.update_bounds::(fsl)?; + } + DataType::Float32 => { + quantizer.update_bounds::(fsl)?; + } + DataType::Float64 => { + quantizer.update_bounds::(fsl)?; + } + _ => { + return Err(Error::Index { + message: format!("SQ builder: unsupported data type: {}", fsl.value_type()), + location: location!(), + }) + } + } + + Ok(quantizer) + } + fn code_dim(&self) -> usize { self.dim } @@ -180,15 +233,22 @@ impl Quantization for ScalarQuantizer { } fn quantize(&self, vectors: &dyn Array) -> Result { - let code_array = self.transform::(vectors)?; - Ok(code_array) + match vectors.as_fixed_size_list().value_type() { + DataType::Float16 => self.transform::(vectors), + DataType::Float32 => self.transform::(vectors), + DataType::Float64 => self.transform::(vectors), + value_type => Err(Error::invalid_input( + format!("unsupported data type {} for scalar quantizer", value_type), + location!(), + )), + } } fn metadata_key() -> &'static str { SQ_METADATA_KEY } - fn quantization_type(&self) -> QuantizationType { + fn quantization_type() -> QuantizationType { QuantizationType::Scalar } @@ -210,9 +270,14 @@ impl Quantization for ScalarQuantizer { } impl Quantization for Arc { + type BuildParams = (); type Metadata = ProductQuantizationMetadata; type Storage = ProductQuantizationStorage; + fn build(_: &dyn Array, _: DistanceType, _: &Self::BuildParams) -> Result { + unimplemented!("ProductQuantizer cannot be built with new index builder") + } + fn code_dim(&self) -> usize { self.num_sub_vectors() } @@ -230,7 +295,7 @@ impl Quantization for Arc { PQ_METADTA_KEY } - fn quantization_type(&self) -> QuantizationType { + fn quantization_type() -> QuantizationType { QuantizationType::Product } @@ -278,9 +343,14 @@ impl Quantization for ProductQuantizerImpl where T::Native: Dot + L2, { + type BuildParams = (); type Metadata = ProductQuantizationMetadata; type Storage = ProductQuantizationStorage; + fn build(_: &dyn Array, _: DistanceType, _: &Self::BuildParams) -> Result { + unimplemented!("ProductQuantizer cannot be built with new index builder") + } + fn code_dim(&self) -> usize { self.num_sub_vectors() } @@ -298,7 +368,7 @@ where PQ_METADTA_KEY } - fn quantization_type(&self) -> QuantizationType { + fn quantization_type() -> QuantizationType { QuantizationType::Product } @@ -348,7 +418,7 @@ pub struct IvfQuantizationStorage { quantizer: Quantizer, metadata: Q::Metadata, - ivf: IvfData, + ivf: IvfModel, } impl DeepSizeOf for IvfQuantizationStorage { @@ -398,7 +468,7 @@ impl 
IvfQuantizationStorage { })?; let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?; - let ivf_data = IvfData::load(&reader).await?; + let ivf_data = IvfModel::load(&reader).await?; let metadata = Q::Metadata::load(&reader).await?; let quantizer = Q::from_metadata(&metadata, distance_type)?; diff --git a/rust/lance-index/src/vector/sq.rs b/rust/lance-index/src/vector/sq.rs index a3c7f5238c..1ef31a9dec 100644 --- a/rust/lance-index/src/vector/sq.rs +++ b/rust/lance-index/src/vector/sq.rs @@ -13,6 +13,8 @@ use lance_core::{Error, Result}; use num_traits::*; use snafu::{location, Location}; +use super::quantizer::Quantizer; + pub mod builder; pub mod storage; pub mod transform; @@ -129,6 +131,19 @@ impl ScalarQuantizer { } } +impl TryFrom for ScalarQuantizer { + type Error = Error; + fn try_from(value: Quantizer) -> Result { + match value { + Quantizer::Scalar(sq) => Ok(sq), + _ => Err(Error::Index { + message: "Expect to be a ScalarQuantizer".to_string(), + location: location!(), + }), + } + } +} + pub(crate) fn scale_to_u8(values: &[T::Native], bounds: Range) -> Vec { let range = bounds.end - bounds.start; values diff --git a/rust/lance-index/src/vector/sq/builder.rs b/rust/lance-index/src/vector/sq/builder.rs index fb885e2cd2..913751062c 100644 --- a/rust/lance-index/src/vector/sq/builder.rs +++ b/rust/lance-index/src/vector/sq/builder.rs @@ -1,18 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use arrow::{ - array::AsArray, - datatypes::{Float16Type, Float32Type, Float64Type}, -}; -use arrow_array::Array; - -use arrow_schema::DataType; -use lance_core::{Error, Result}; -use lance_linalg::distance::DistanceType; -use snafu::{location, Location}; - -use super::ScalarQuantizer; +use crate::vector::quantizer::QuantizerBuildParams; #[derive(Debug, Clone)] pub struct SQBuildParams { @@ -32,36 +21,8 @@ impl Default for SQBuildParams { } } -impl SQBuildParams { - pub fn build(&self, data: &dyn Array, _: DistanceType) -> Result { - let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index { - message: format!( - "SQ builder: input is not a FixedSizeList: {}", - data.data_type() - ), - location: location!(), - })?; - - let mut quantizer = ScalarQuantizer::new(self.num_bits, fsl.value_length() as usize); - - match fsl.value_type() { - DataType::Float16 => { - quantizer.update_bounds::(fsl)?; - } - DataType::Float32 => { - quantizer.update_bounds::(fsl)?; - } - DataType::Float64 => { - quantizer.update_bounds::(fsl)?; - } - _ => { - return Err(Error::Index { - message: format!("SQ builder: unsupported data type: {}", fsl.value_type()), - location: location!(), - }) - } - } - - Ok(quantizer) +impl QuantizerBuildParams for SQBuildParams { + fn sample_size(&self) -> usize { + self.sample_rate * 2usize.pow(self.num_bits as u32) } } diff --git a/rust/lance-index/src/vector/sq/storage.rs b/rust/lance-index/src/vector/sq/storage.rs index 52fa3ff84f..beb83f5851 100644 --- a/rust/lance-index/src/vector/sq/storage.rs +++ b/rust/lance-index/src/vector/sq/storage.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::{collections::HashMap, ops::Range, sync::Arc}; +use std::ops::Range; use arrow::compute::concat_batches; use arrow_array::{ @@ -21,6 +21,7 @@ use object_store::path::Path; use serde::{Deserialize, Serialize}; use snafu::{location, Location}; +use crate::vector::storage::STORAGE_METADATA_KEY; use crate::{ vector::{ quantizer::{QuantizerMetadata, 
QuantizerStorage}, @@ -70,7 +71,7 @@ impl QuantizerMetadata for ScalarQuantizationMetadata { } /// An immutable chunk of SclarQuantizationStorage. -#[derive(Clone)] +#[derive(Debug, Clone)] struct SQStorageChunk { batch: RecordBatch, @@ -150,7 +151,7 @@ impl DeepSizeOf for SQStorageChunk { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct ScalarQuantizationStorage { quantizer: ScalarQuantizer, @@ -291,7 +292,7 @@ impl VectorStore for ScalarQuantizationStorage { let metadata_json = batch .schema_ref() .metadata() - .get("metadata") + .get(STORAGE_METADATA_KEY) .ok_or(Error::Schema { message: "metadata not found".to_string(), location: location!(), @@ -302,27 +303,7 @@ impl VectorStore for ScalarQuantizationStorage { } fn to_batches(&self) -> Result> { - let metadata = ScalarQuantizationMetadata { - dim: self.chunks[0].dim(), - num_bits: self.quantizer.num_bits, - bounds: self.quantizer.bounds.clone(), - }; - let metadata_json = serde_json::to_string(&metadata)?; - let metadata = HashMap::from_iter(vec![("metadata".to_owned(), metadata_json)]); - - let schema = self.chunks[0] - .schema() - .as_ref() - .clone() - .with_metadata(metadata); - let schema = Arc::new(schema); - Ok(self.chunks.iter().map(move |chunk| { - chunk - .batch - .clone() - .with_schema(schema.clone()) - .expect("attach schema") - })) + Ok(self.chunks.iter().map(|c| c.batch.clone())) } fn append_batch(&self, batch: RecordBatch, vector_column: &str) -> Result { @@ -458,6 +439,7 @@ mod tests { use super::*; use std::iter::repeat_with; + use std::sync::Arc; use arrow_array::FixedSizeListArray; use arrow_schema::{DataType, Field, Schema}; diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index ae5834de8a..d11a132801 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -22,12 +22,13 @@ use snafu::{location, Location}; use crate::{ pb, vector::{ - ivf::storage::{IvfData, IVF_METADATA_KEY}, - quantizer::{Quantization, Quantizer}, + ivf::storage::{IvfModel, IVF_METADATA_KEY}, + quantizer::Quantization, }, INDEX_METADATA_SCHEMA_KEY, }; +use super::quantizer::Quantizer; use super::DISTANCE_TYPE_KEY; ///
@@ -40,6 +41,8 @@ pub trait DistCalculator { fn prefetch(&self, _id: u32) {} } +pub const STORAGE_METADATA_KEY: &str = "storage_metadata"; + /// Vector Storage is the abstraction to store the vectors. /// /// It can be in-memory or on-disk, raw vector or quantized vectors. @@ -121,13 +124,19 @@ impl StorageBuilder { location: location!(), })?; let code_array = self.quantizer.quantize(vectors.as_ref())?; - let batch = batch.drop_column(&self.column)?.try_with_column( - Field::new( - self.quantizer.column(), - code_array.data_type().clone(), - true, - ), - code_array, + let batch = batch + .try_with_column( + Field::new( + self.quantizer.column(), + code_array.data_type().clone(), + true, + ), + code_array, + )? + .drop_column(&self.column)?; + let batch = batch.add_metadata( + STORAGE_METADATA_KEY.to_owned(), + self.quantizer.metadata(None)?.to_string(), )?; Q::Storage::try_from_batch(batch, self.distance_type) } @@ -135,30 +144,27 @@ impl StorageBuilder { /// Loader to load partitioned PQ storage from disk. #[derive(Debug)] -pub struct IvfQuantizationStorage { +pub struct IvfQuantizationStorage { reader: FileReader, distance_type: DistanceType, - quantizer: Quantizer, - metadata: Q::Metadata, + metadata: Vec, - ivf: IvfData, + ivf: IvfModel, } -impl DeepSizeOf for IvfQuantizationStorage { +impl DeepSizeOf for IvfQuantizationStorage { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { - self.quantizer.deep_size_of_children(context) - + self.metadata.deep_size_of_children(context) - + self.ivf.deep_size_of_children(context) + self.metadata.deep_size_of_children(context) + self.ivf.deep_size_of_children(context) } } #[allow(dead_code)] -impl IvfQuantizationStorage { +impl IvfQuantizationStorage { /// Open a Loader. /// /// - pub async fn open(reader: FileReader) -> Result { + pub async fn try_new(reader: FileReader) -> Result { let schema = reader.schema(); let distance_type = DistanceType::try_from( @@ -185,34 +191,29 @@ impl IvfQuantizationStorage { location: location!(), })?; let ivf_bytes = reader.read_global_buffer(ivf_pos).await?; - let ivf = IvfData::try_from(pb::Ivf::decode(ivf_bytes)?)?; + let ivf = IvfModel::try_from(pb::Ivf::decode(ivf_bytes)?)?; - let quantizer_metadata: Q::Metadata = serde_json::from_str( + let metadata: Vec = serde_json::from_str( schema .metadata - .get(Q::metadata_key()) + .get(STORAGE_METADATA_KEY) .ok_or(Error::Index { - message: format!("{} not found", Q::metadata_key()), + message: format!("{} not found", STORAGE_METADATA_KEY), location: location!(), })? .as_str(), )?; - let quantizer = Q::from_metadata(&quantizer_metadata, distance_type)?; Ok(Self { reader, distance_type, - quantizer, - metadata: quantizer_metadata, + metadata, ivf, }) } - pub fn quantizer(&self) -> &Quantizer { - &self.quantizer - } - - pub fn metadata(&self) -> &Q::Metadata { - &self.metadata + pub fn quantizer(&self) -> Result { + let metadata = serde_json::from_str(&self.metadata[0])?; + Q::from_metadata(&metadata, self.distance_type) } /// Get the number of partitions in the storage. 
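The `STORAGE_METADATA_KEY` plumbing above boils down to storing a JSON blob in the Arrow schema metadata and recovering it from any batch that carries the schema. A small self-contained sketch with arrow-rs (the key matches the constant introduced in this patch; the JSON payload is made up for illustration):

```rust
use std::{collections::HashMap, sync::Arc};

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Attach a serialized metadata blob to the schema under the storage key.
    let metadata = HashMap::from([(
        "storage_metadata".to_string(),            // STORAGE_METADATA_KEY
        r#"{"dim":128,"num_bits":8}"#.to_string(), // hypothetical payload
    )]);
    let schema = Arc::new(Schema::new_with_metadata(
        vec![Field::new("code", DataType::Int32, false)],
        metadata,
    ));
    let codes: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema, vec![codes])?;

    // Any batch carrying that schema can recover the blob.
    let blob = batch
        .schema()
        .metadata()
        .get("storage_metadata")
        .cloned()
        .ok_or("storage_metadata missing")?;
    assert!(blob.contains("num_bits"));
    Ok(())
}
```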
@@ -220,7 +221,7 @@ impl IvfQuantizationStorage { self.ivf.num_partitions() } - pub async fn load_partition(&self, part_id: usize) -> Result { + pub async fn load_partition(&self, part_id: usize) -> Result { let range = self.ivf.row_range(part_id); let batches = self .reader @@ -234,6 +235,10 @@ impl IvfQuantizationStorage { .await?; let schema = Arc::new(self.reader.schema().as_ref().into()); let batch = concat_batches(&schema, batches.iter())?; + let batch = batch.add_metadata( + STORAGE_METADATA_KEY.to_owned(), + self.metadata[part_id].clone(), + )?; Q::Storage::try_from_batch(batch, self.distance_type) } } diff --git a/rust/lance-index/src/vector/v3/shuffler.rs b/rust/lance-index/src/vector/v3/shuffler.rs index 6377526911..16d1911d37 100644 --- a/rust/lance-index/src/vector/v3/shuffler.rs +++ b/rust/lance-index/src/vector/v3/shuffler.rs @@ -35,7 +35,7 @@ pub trait ShuffleReader: Send + Sync { ) -> Result>>; /// Get the size of the partition by partition_id - fn partiton_size(&self, partition_id: usize) -> Result; + fn partition_size(&self, partition_id: usize) -> Result; } #[async_trait::async_trait] @@ -51,7 +51,7 @@ pub trait Shuffler: Send + Sync { } pub struct IvfShuffler { - object_store: ObjectStore, + object_store: Arc, output_dir: Path, num_partitions: usize, @@ -60,9 +60,9 @@ pub struct IvfShuffler { } impl IvfShuffler { - pub fn new(object_store: ObjectStore, output_dir: Path, num_partitions: usize) -> Self { + pub fn new(output_dir: Path, num_partitions: usize) -> Self { Self { - object_store, + object_store: Arc::new(ObjectStore::local()), output_dir, num_partitions, buffer_size: 4096, @@ -85,8 +85,6 @@ impl Shuffler for IvfShuffler { let mut partition_sizes = vec![0; self.num_partitions]; let mut first_pass = true; - let mut counter = 0; - let num_partitions = self.num_partitions; let mut parallel_sort_stream = data .map(|batch| { @@ -133,8 +131,8 @@ impl Shuffler for IvfShuffler { .map(|_| Vec::new()) .collect::>(); + let mut counter = 0; while let Some(shuffled) = parallel_sort_stream.next().await { - log::info!("shuffle batch: {}", counter); let shuffled = shuffled?; for (part_id, batches) in shuffled.into_iter().enumerate() { @@ -176,6 +174,7 @@ impl Shuffler for IvfShuffler { // do flush if counter % self.buffer_size == 0 { + log::info!("shuffle {} batches, flushing", counter); let mut futs = vec![]; for (part_id, writer) in writers.iter_mut().enumerate() { let batches = &partition_buffers[part_id]; @@ -208,27 +207,41 @@ impl Shuffler for IvfShuffler { writer.finish().await?; } - Ok(Box::new(IvfShufflerReader { - object_store: self.object_store.clone(), - output_dir: self.output_dir.clone(), + Ok(Box::new(IvfShufflerReader::new( + self.object_store.clone(), + self.output_dir.clone(), partition_sizes, - })) + ))) } } pub struct IvfShufflerReader { - object_store: ObjectStore, + object_store: Arc, output_dir: Path, partition_sizes: Vec, } +impl IvfShufflerReader { + pub fn new( + object_store: Arc, + output_dir: Path, + partition_sizes: Vec, + ) -> Self { + Self { + object_store, + output_dir, + partition_sizes, + } + } +} + #[async_trait::async_trait] impl ShuffleReader for IvfShufflerReader { async fn read_partition( &self, partition_id: usize, ) -> Result>> { - let scheduler = ScanScheduler::new(Arc::new(self.object_store.clone()), 32); + let scheduler = ScanScheduler::new(self.object_store.clone(), 32); let partition_path = self.output_dir.child(format!("ivf_{}.lance", partition_id)); let reader = FileReader::try_open( @@ -250,7 +263,7 @@ impl ShuffleReader for 
IvfShufflerReader { )))) } - fn partiton_size(&self, partition_id: usize) -> Result { + fn partition_size(&self, partition_id: usize) -> Result { Ok(self.partition_sizes[partition_id]) } } diff --git a/rust/lance-index/src/vector/v3/subindex.rs b/rust/lance-index/src/vector/v3/subindex.rs index 933af8f7f4..cd795e9599 100644 --- a/rust/lance-index/src/vector/v3/subindex.rs +++ b/rust/lance-index/src/vector/v3/subindex.rs @@ -1,19 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::fmt::Debug; use std::sync::Arc; use arrow_array::{ArrayRef, RecordBatch}; use deepsize::DeepSizeOf; -use lance_core::Result; +use lance_core::{Error, Result}; +use snafu::{location, Location}; use crate::vector::storage::VectorStore; +use crate::vector::{flat, hnsw}; use crate::{prefilter::PreFilter, vector::Query}; - -pub const SUB_INDEX_METADATA_KEY: &str = "sub_index_metadata"; - /// A sub index for IVF index -pub trait IvfSubIndex: Send + Sync + DeepSizeOf { +pub trait IvfSubIndex: Send + Sync + Debug + DeepSizeOf { type QueryParams: Send + Sync + for<'a> From<&'a Query>; type BuildParams: Clone; @@ -26,6 +26,8 @@ pub trait IvfSubIndex: Send + Sync + DeepSizeOf { fn name() -> &'static str; + fn metadata_key() -> &'static str; + /// Return the schema of the sub index fn schema() -> arrow_schema::SchemaRef; @@ -52,3 +54,32 @@ pub trait IvfSubIndex: Send + Sync + DeepSizeOf { /// Encode the sub index into a record batch fn to_batch(&self) -> Result; } + +pub enum SubIndexType { + Flat, + Hnsw, +} + +impl std::fmt::Display for SubIndexType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Flat => write!(f, "{}", flat::index::FlatIndex::name()), + Self::Hnsw => write!(f, "{}", hnsw::builder::HNSW::name()), + } + } +} + +impl TryFrom<&str> for SubIndexType { + type Error = Error; + + fn try_from(value: &str) -> Result { + match value { + "FLAT" => Ok(Self::Flat), + "HNSW" => Ok(Self::Hnsw), + _ => Err(Error::Index { + message: format!("unknown sub index type {}", value), + location: location!(), + }), + } + } +} diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index e4886ad312..d536b6a936 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -12,14 +12,19 @@ use async_trait::async_trait; use futures::{stream, StreamExt, TryStreamExt}; use itertools::Itertools; use lance_file::reader::FileReader; +use lance_file::v2; use lance_index::optimize::OptimizeOptions; use lance_index::pb::index::Implementation; use lance_index::scalar::expression::IndexInformationProvider; use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::scalar::ScalarIndex; use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::hnsw::HNSW; +use lance_index::vector::sq::ScalarQuantizer; pub use lance_index::IndexParams; +use lance_index::INDEX_METADATA_SCHEMA_KEY; use lance_index::{pb, vector::VectorIndex, DatasetIndexExt, Index, IndexType, INDEX_FILE_NAME}; +use lance_io::scheduler::ScanScheduler; use lance_io::traits::Reader; use lance_io::utils::{ read_last_block, read_message, read_message_from_buf, read_metadata_offset, read_version, @@ -519,7 +524,7 @@ impl DatasetIndexInternalExt for Dataset { // the index file is in lance format since version (0,2) // TODO: we need to change the legacy IVF_PQ to be in lance format - match (major_version, minor_version) { + let index = match (major_version, minor_version) { (0, 1) | (0, 0) => { let proto = 
open_index_proto(reader.as_ref()).await?; match &proto.implementation { @@ -557,14 +562,70 @@ impl DatasetIndexInternalExt for Dataset { } (0, 3) => { - let ivf = IVFIndex::::try_new( - self.object_store.clone(), - self.indices_dir(), - uuid.to_owned(), - Arc::downgrade(&self.session), - ) - .await?; - Ok(Arc::new(ivf)) + let scheduler = ScanScheduler::new(self.object_store.clone(), 16); + let file = scheduler.open_file(&index_file).await?; + let reader = + v2::reader::FileReader::try_open(file, None, Default::default()).await?; + let index_metadata = reader + .schema() + .metadata + .get(INDEX_METADATA_SCHEMA_KEY) + .ok_or(Error::Index { + message: "Index Metadata not found".to_owned(), + location: location!(), + })?; + let index_metadata: lance_index::IndexMetadata = + serde_json::from_str(index_metadata)?; + let field = self.schema().field(column).ok_or_else(|| Error::Index { + message: format!("Column {} does not exist in the schema", column), + location: location!(), + })?; + + let value_type = if let DataType::FixedSizeList(df, _) = field.data_type() { + Result::Ok(df.data_type().to_owned()) + } else { + return Err(Error::Index { + message: format!("Column {} is not a vector column", column), + location: location!(), + }); + }?; + match index_metadata.index_type.as_str() { + "FLAT" => match value_type { + DataType::Float16 | DataType::Float32 | DataType::Float64 => { + let ivf = IVFIndex::::try_new( + self.object_store.clone(), + self.indices_dir(), + uuid.to_owned(), + Arc::downgrade(&self.session), + ) + .await?; + Ok(Arc::new(ivf) as Arc) + } + _ => Err(Error::Index { + message: format!( + "the field type {} is not supported for FLAT index", + field.data_type() + ), + location: location!(), + }), + }, + + "HNSW" => { + let ivf = IVFIndex::::try_new( + self.object_store.clone(), + self.indices_dir(), + uuid.to_owned(), + Arc::downgrade(&self.session), + ) + .await?; + Ok(Arc::new(ivf) as Arc) + } + + _ => Err(Error::Index { + message: format!("Unsupported index type: {}", index_metadata.index_type), + location: location!(), + }), + } } _ => Err(Error::Index { @@ -572,7 +633,10 @@ impl DatasetIndexInternalExt for Dataset { .to_owned(), location: location!(), }), - } + }; + let index = index?; + self.session.index_cache.insert_vector(uuid, index.clone()); + Ok(index) } async fn scalar_index_info(&self) -> Result { diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 19aecf2424..6e32234042 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -10,7 +10,6 @@ use std::{any::Any, collections::HashMap}; pub mod builder; pub mod ivf; pub mod pq; -pub mod sq; mod traits; mod utils; @@ -21,7 +20,8 @@ use arrow::datatypes::Float32Type; use builder::IvfIndexBuilder; use lance_file::reader::FileReader; use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; -use lance_index::vector::ivf::storage::IvfData; +use lance_index::vector::hnsw::HNSW; +use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::ProductQuantizerImpl; use lance_index::vector::v3::shuffler::IvfShuffler; use lance_index::vector::{ @@ -40,7 +40,6 @@ use lance_linalg::distance::*; use lance_table::format::Index as IndexMetadata; use snafu::{location, Location}; use tracing::instrument; -use utils::get_vector_dim; use uuid::Uuid; use self::{ivf::*, pq::PQIndex}; @@ -79,20 +78,6 @@ impl VectorIndexParams { } } - // IVF_HNSW index - pub fn ivf_hnsw( - num_partitions: usize, - distance_type: DistanceType, - hnsw_params: HnswBuildParams, - ) -> 
Self { - let ivf_params = IvfBuildParams::new(num_partitions); - let stages = vec![StageParams::Ivf(ivf_params), StageParams::Hnsw(hnsw_params)]; - Self { - stages, - metric_type: distance_type, - } - } - /// Create index parameters for `IVF_PQ` index. /// /// Parameters @@ -146,11 +131,6 @@ impl VectorIndexParams { hnsw: HnswBuildParams, pq: PQBuildParams, ) -> Self { - let hnsw = match &hnsw.parallel_limit { - Some(_) => hnsw, - None => hnsw.parallel_limit(num_cpus::get().div_ceil(ivf.num_partitions)), - }; - let stages = vec![ StageParams::Ivf(ivf), StageParams::Hnsw(hnsw), @@ -170,11 +150,6 @@ impl VectorIndexParams { hnsw: HnswBuildParams, sq: SQBuildParams, ) -> Self { - let hnsw = match &hnsw.parallel_limit { - Some(_) => hnsw, - None => hnsw.parallel_limit(num_cpus::get().div_ceil(ivf.num_partitions)), - }; - let stages = vec![ StageParams::Ivf(ivf), StageParams::Hnsw(hnsw), @@ -237,7 +212,6 @@ pub(crate) async fn build_vector_index( params: &VectorIndexParams, ) -> Result<()> { let stages = ¶ms.stages; - let dim = get_vector_dim(dataset, column)?; if stages.is_empty() { return Err(Error::Index { @@ -255,21 +229,16 @@ pub(crate) async fn build_vector_index( }; let temp_dir = tempfile::tempdir()?; let path = temp_dir.path().to_str().unwrap().into(); - let shuffler = IvfShuffler::new( - dataset.object_store().clone(), - path, - ivf_params.num_partitions, - ); - let quantizer = FlatQuantizer::new(dim, params.metric_type); - IvfIndexBuilder::::new( + let shuffler = IvfShuffler::new(path, ivf_params.num_partitions); + IvfIndexBuilder::::new( dataset.clone(), column.to_owned(), dataset.indices_dir().child(uuid), params.metric_type, Box::new(shuffler), - ivf_params.clone(), + Some(ivf_params.clone()), + Some(()), (), - quantizer, )? .build() .await?; @@ -313,6 +282,10 @@ pub(crate) async fn build_vector_index( }); }; + let temp_dir = tempfile::tempdir()?; + let path = temp_dir.path().to_str().unwrap().into(); + let shuffler = IvfShuffler::new(path, ivf_params.num_partitions); + // with quantization if len > 2 { match stages.last().unwrap() { @@ -330,17 +303,18 @@ pub(crate) async fn build_vector_index( .await? } StageParams::SQ(sq_params) => { - build_ivf_hnsw_sq_index( - dataset, - column, - name, - uuid, + IvfIndexBuilder::::new( + dataset.clone(), + column.to_owned(), + dataset.indices_dir().child(uuid), params.metric_type, - ivf_params, - hnsw_params, - sq_params, - ) - .await? + Box::new(shuffler), + Some(ivf_params.clone()), + Some(sq_params.clone()), + hnsw_params.clone(), + )? 
+ .build() + .await?; } _ => { return Err(Error::Index { @@ -431,7 +405,7 @@ pub(crate) async fn open_vector_index( location: location!(), }); } - let ivf = Ivf::try_from(ivf_pb)?; + let ivf = IvfModel::try_from(ivf_pb.to_owned())?; last_stage = Some(Arc::new(IVFIndex::try_new( dataset.session.clone(), uuid, @@ -468,7 +442,6 @@ pub(crate) async fn open_vector_index( }); } let idx = last_stage.unwrap(); - dataset.session.index_cache.insert_vector(uuid, idx.clone()); Ok(idx) } @@ -498,7 +471,7 @@ pub(crate) async fn open_vector_index_v2( .child(INDEX_AUXILIARY_FILE_NAME); let aux_reader = dataset.object_store().open(&aux_path).await?; - let ivf_data = IvfData::load(&reader).await?; + let ivf_data = IvfModel::load(&reader).await?; let options = HNSWIndexOptions { use_residual: true }; let hnsw = HNSWIndex::>::try_new( reader.object_reader.clone(), @@ -507,7 +480,7 @@ pub(crate) async fn open_vector_index_v2( ) .await?; let pb_ivf = pb::Ivf::try_from(&ivf_data)?; - let ivf = Ivf::try_from(&pb_ivf)?; + let ivf = IvfModel::try_from(pb_ivf)?; Arc::new(IVFIndex::try_new( dataset.session.clone(), @@ -526,7 +499,7 @@ pub(crate) async fn open_vector_index_v2( .child(INDEX_AUXILIARY_FILE_NAME); let aux_reader = dataset.object_store().open(&aux_path).await?; - let ivf_data = IvfData::load(&reader).await?; + let ivf_data = IvfModel::load(&reader).await?; let options = HNSWIndexOptions { use_residual: false, }; @@ -538,7 +511,7 @@ pub(crate) async fn open_vector_index_v2( ) .await?; let pb_ivf = pb::Ivf::try_from(&ivf_data)?; - let ivf = Ivf::try_from(&pb_ivf)?; + let ivf = IvfModel::try_from(pb_ivf)?; Arc::new(IVFIndex::try_new( dataset.session.clone(), @@ -573,10 +546,5 @@ pub(crate) async fn open_vector_index_v2( } }; - dataset - .session - .index_cache - .insert_vector(uuid, index.clone()); - Ok(index) } diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 6d3c194b25..536c1dda66 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -3,19 +3,24 @@ use std::sync::Arc; -use arrow_array::{FixedSizeListArray, RecordBatch}; +use arrow::array::AsArray; +use arrow_array::{RecordBatch, UInt64Array}; use futures::prelude::stream::{StreamExt, TryStreamExt}; use itertools::Itertools; -use lance_core::{Error, Result}; +use lance_arrow::RecordBatchExt; +use lance_core::{Error, Result, ROW_ID_FIELD}; use lance_encoding::decoder::{DecoderMiddlewareChain, FilterExpression}; use lance_file::v2::{reader::FileReader, writer::FileWriter}; +use lance_index::vector::flat::storage::FlatStorage; +use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::quantizer::QuantizerBuildParams; +use lance_index::vector::storage::STORAGE_METADATA_KEY; +use lance_index::vector::v3::shuffler::IvfShufflerReader; +use lance_index::vector::VectorIndex; use lance_index::{ pb, vector::{ - ivf::{ - storage::{IvfData, IVF_METADATA_KEY}, - IvfBuildParams, - }, + ivf::{storage::IVF_METADATA_KEY, IvfBuildParams}, quantizer::Quantization, storage::{StorageBuilder, VectorStore}, transform::Transformer, @@ -27,11 +32,14 @@ use lance_index::{ }, INDEX_AUXILIARY_FILE_NAME, INDEX_FILE_NAME, }; +use lance_index::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; +use lance_io::stream::RecordBatchStream; use lance_io::{ object_store::ObjectStore, scheduler::ScanScheduler, stream::RecordBatchStreamAdapter, ReadBatchParams, }; use lance_linalg::distance::DistanceType; +use log::info; use object_store::path::Path; use prost::Message; use snafu::{location, 
Location}; @@ -39,21 +47,39 @@ use tempfile::TempDir; use crate::Dataset; -use super::{utils, Ivf}; +use super::utils; +use super::v2::IVFIndex; +// Builder for IVF index +// The builder will train the IVF model and quantizer, shuffle the dataset, and build the sub index +// for each partition. +// To build the index for the whole dataset, call the `build` method. +// To build the index for a given IVF, quantizer, and data stream, +// call `with_ivf`, `with_quantizer`, `shuffle_data`, and `build` in order. pub struct IvfIndexBuilder { dataset: Dataset, column: String, index_dir: Path, distance_type: DistanceType, shuffler: Arc, - ivf_params: IvfBuildParams, + // build params, only needed for building new IVF, quantizer + ivf_params: Option, + quantizer_params: Option, sub_index_params: S::BuildParams, - quantizer: Q, + _temp_dir: TempDir, // store this for keeping the temp dir alive temp_dir: Path, + + // fields will be set during build + ivf: Option, + quantizer: Option, + shuffle_reader: Option>, + partition_sizes: Vec<(usize, usize)>, + + // fields for merging indices + existing_indices: Vec>, } -impl IvfIndexBuilder { +impl IvfIndexBuilder { #[allow(clippy::too_many_arguments)] pub fn new( dataset: Dataset, @@ -61,12 +87,12 @@ impl IvfIndexBuilder { column: String, index_dir: Path, distance_type: DistanceType, shuffler: Box, - ivf_params: IvfBuildParams, + ivf_params: Option, + quantizer_params: Option, sub_index_params: S::BuildParams, - quantizer: Q, ) -> Result { let temp_dir = TempDir::new()?; - let temp_dir = Path::from(temp_dir.path().to_str().unwrap()); + let temp_dir_path = temp_dir.path().to_str().unwrap().into(); Ok(Self { dataset, column, @@ -74,74 +100,136 @@ impl IvfIndexBuilder { distance_type, shuffler: shuffler.into(), ivf_params, + quantizer_params, sub_index_params, - quantizer, - temp_dir, + _temp_dir: temp_dir, + temp_dir: temp_dir_path, + // fields will be set during build + ivf: None, + quantizer: None, + shuffle_reader: None, + partition_sizes: Vec::new(), + existing_indices: Vec::new(), }) } - pub async fn build(&self) -> Result<()> { - // step 1. train IVF - let ivf = self.load_or_build_ivf().await?; - - // step 2. shuffle data - let reader = self.shuffle_data(ivf.centroids.clone()).await?; - let partition_build_order = (0..self.ivf_params.num_partitions) - .map(|partition_id| reader.partiton_size(partition_id)) - .collect::>>()? - // sort by partition size in descending order - .into_iter() - .enumerate() - .map(|(idx, x)| (x, idx)) - .sorted() - .rev() - .map(|(_, idx)| idx) - .collect::>(); + pub fn new_incremental( + dataset: Dataset, + column: String, + index_dir: Path, + distance_type: DistanceType, + shuffler: Box, + sub_index_params: S::BuildParams, + ) -> Result { + Self::new( + dataset, + column, + index_dir, + distance_type, + shuffler, + None, + None, + sub_index_params, + ) + } - // step 3. build sub index - let mut partition_sizes = Vec::with_capacity(self.ivf_params.num_partitions); - for &partition in &partition_build_order { - let partition_data = reader.read_partition(partition).await?.ok_or(Error::io( - format!("partition {} is empty", partition).as_str(), - location!(), - ))?; - let batches = partition_data.try_collect::>().await?; - let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; + // build the index with all the data in the dataset + pub async fn build(&mut self) -> Result<()> { + // step 1.
+        if self.ivf.is_none() {
+            self.with_ivf(self.load_or_build_ivf().await?);
+        }
+        if self.quantizer.is_none() {
+            self.with_quantizer(self.load_or_build_quantizer().await?);
+        }
 
-            let sizes = self.build_partition(partition, &batch).await?;
-            partition_sizes.push(sizes);
+        // step 2. shuffle the dataset
+        if self.partition_sizes.is_empty() {
+            self.shuffle_dataset().await?;
         }
 
+        // step 3. build partitions
+        self.build_partitions().await?;
+
         // step 4. merge all partitions
-        self.merge_partitions(ivf.centroids, partition_sizes)
-            .await?;
+        self.merge_partitions().await?;
 
         Ok(())
     }
 
-    async fn load_or_build_ivf(&self) -> Result {
+    pub fn with_ivf(&mut self, ivf: IvfModel) -> &mut Self {
+        self.ivf = Some(ivf);
+        self
+    }
+
+    pub fn with_quantizer(&mut self, quantizer: Q) -> &mut Self {
+        self.quantizer = Some(quantizer);
+        self
+    }
+
+    pub fn with_existing_indices(&mut self, indices: Vec>) -> &mut Self {
+        self.existing_indices = indices;
+        self
+    }
+
+    async fn load_or_build_ivf(&self) -> Result {
+        let ivf_params = self.ivf_params.as_ref().ok_or(Error::invalid_input(
+            "IVF build params not set",
+            location!(),
+        ))?;
         let dim = utils::get_vector_dim(&self.dataset, &self.column)?;
         super::build_ivf_model(
             &self.dataset,
             &self.column,
             dim,
             self.distance_type,
-            &self.ivf_params,
+            ivf_params,
         )
         .await
         // TODO: load ivf model
     }
 
-    async fn shuffle_data(&self, centroids: FixedSizeListArray) -> Result> {
-        let transformer = Arc::new(lance_index::vector::ivf::new_ivf_with_quantizer(
-            centroids,
-            self.distance_type,
-            &self.column,
-            self.quantizer.clone().into(),
-            Some(0..self.ivf_params.num_partitions as u32),
-        )?);
+    async fn load_or_build_quantizer(&self) -> Result {
+        let quantizer_params = self.quantizer_params.as_ref().ok_or(Error::invalid_input(
+            "quantizer build params not set",
+            location!(),
+        ))?;
+        let sample_size_hint = quantizer_params.sample_size();
+
+        let start = std::time::Instant::now();
+        info!(
+            "loading training data for quantizer. sample size: {}",
+            sample_size_hint
+        );
+        let training_data =
+            utils::maybe_sample_training_data(&self.dataset, &self.column, sample_size_hint)
+                .await?;
+        info!(
+            "Finished loading training data in {:.2} seconds",
+            start.elapsed().as_secs_f32()
+        );
+
+        // If the metric type is cosine, normalize the training data; after this point,
+        // treat the metric type as L2.
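+        // For unit vectors a and b, ||a - b||^2 = 2 - 2 * dot(a, b), so after
+        // normalization the L2 ordering of candidates matches the cosine
+        // ordering and the cheaper L2 kernels can be used from here on.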
+        let (training_data, dt) = if self.distance_type == DistanceType::Cosine {
+            let training_data = lance_linalg::kernels::normalize_fsl(&training_data)?;
+            (training_data, DistanceType::L2)
+        } else {
+            (training_data, self.distance_type)
+        };
+
+        info!("Start to train quantizer");
+        let start = std::time::Instant::now();
+        let quantizer = Q::build(&training_data, dt, quantizer_params)?;
+        info!(
+            "Trained quantizer in {:.2} seconds",
+            start.elapsed().as_secs_f32()
+        );
+        Ok(quantizer)
+    }
 
+    async fn shuffle_dataset(&mut self) -> Result<()> {
         let stream = self
             .dataset
             .scan()
@@ -150,78 +238,229 @@
             .with_row_id()
             .try_into_stream()
             .await?;
+        self.shuffle_data(Some(stream)).await?;
+        Ok(())
+    }
 
+    // shuffle the unindexed data and existing indices
+    // the input data must have schema | ROW_ID | vector_column |
+    // the shuffled data will have schema | ROW_ID | PART_ID | code_column |
+    pub async fn shuffle_data(
+        &mut self,
+        data: Option,
+    ) -> Result<&mut Self> {
+        if data.is_none() {
+            return Ok(self);
+        }
+        let data = data.unwrap();
+
+        let ivf = self.ivf.as_ref().ok_or(Error::invalid_input(
+            "IVF not set before shuffle data",
+            location!(),
+        ))?;
+        let quantizer = self.quantizer.clone().ok_or(Error::invalid_input(
+            "quantizer not set before shuffle data",
+            location!(),
+        ))?;
+
+        let transformer = Arc::new(
+            lance_index::vector::ivf::new_ivf_transformer_with_quantizer(
+                ivf.centroids.clone().unwrap(),
+                self.distance_type,
+                &self.column,
+                quantizer.into(),
+                Some(0..ivf.num_partitions() as u32),
+            )?,
+        );
         let mut transformed_stream = Box::pin(
-            stream
-                .map(move |batch| {
-                    let ivf_transformer = transformer.clone();
-                    tokio::spawn(async move { ivf_transformer.transform(&batch?) })
-                })
-                .buffered(num_cpus::get())
-                .map(|x| x.unwrap())
-                .peekable(),
+            data.map(move |batch| {
+                let ivf_transformer = transformer.clone();
+                tokio::spawn(async move { ivf_transformer.transform(&batch?) })
+            })
+            .buffered(num_cpus::get())
+            .map(|x| x.unwrap())
+            .peekable(),
         );
 
         let batch = transformed_stream.as_mut().peek().await;
         let schema = match batch {
             Some(Ok(b)) => b.schema(),
             Some(Err(e)) => panic!("do this better: error reading first batch: {:?}", e),
-            None => panic!("no data"),
+            None => {
+                log::info!("no data to shuffle");
+                self.shuffle_reader = Some(Box::new(IvfShufflerReader::new(
+                    self.dataset.object_store.clone(),
+                    self.temp_dir.clone(),
+                    vec![0; ivf.num_partitions()],
+                )));
+                return Ok(self);
+            }
         };
 
-        self.shuffler
-            .shuffle(Box::new(RecordBatchStreamAdapter::new(
-                schema,
-                transformed_stream,
-            )))
-            .await
+        self.shuffle_reader = Some(
+            self.shuffler
+                .shuffle(Box::new(RecordBatchStreamAdapter::new(
+                    schema,
+                    transformed_stream,
+                )))
+                .await?,
+        );
+
+        Ok(self)
+    }
+
+    async fn build_partitions(&mut self) -> Result<&mut Self> {
+        let ivf = self.ivf.as_ref().ok_or(Error::invalid_input(
+            "IVF not set before building partitions",
+            location!(),
+        ))?;
+
+        let reader = self.shuffle_reader.as_ref().ok_or(Error::invalid_input(
+            "shuffle reader not set before building partitions",
+            location!(),
+        ))?;
+
+        let partition_build_order = (0..ivf.num_partitions())
+            .map(|partition_id| reader.partition_size(partition_id))
+            .collect::>>()?
+ // sort by partition size in descending order + .into_iter() + .enumerate() + .sorted_unstable_by(|(_, a), (_, b)| b.cmp(a)) + .map(|(idx, _)| idx) + .collect::>(); + + let mut partition_sizes = vec![(0, 0); ivf.num_partitions()]; + for (i, &partition) in partition_build_order.iter().enumerate() { + log::info!( + "building partition {}, progress {}/{}", + partition, + i + 1, + ivf.num_partitions(), + ); + let mut batches = Vec::new(); + for existing_index in self.existing_indices.iter() { + let existing_index = existing_index + .as_any() + .downcast_ref::>() + .ok_or(Error::invalid_input( + "existing index is not IVF index", + location!(), + ))?; + + let part_storage = existing_index.load_partition_storage(partition).await?; + batches.extend( + self.take_vectors(part_storage.row_ids().cloned().collect_vec().as_ref()) + .await?, + ); + } + + match reader.partition_size(partition)? { + 0 => continue, + _ => { + let partition_data = + reader.read_partition(partition).await?.ok_or(Error::io( + format!("partition {} is empty", partition).as_str(), + location!(), + ))?; + batches.extend(partition_data.try_collect::>().await?); + } + } + + let num_rows = batches.iter().map(|b| b.num_rows()).sum::(); + if num_rows == 0 { + continue; + } + let mut batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; + if self.distance_type == DistanceType::Cosine { + let vectors = batch + .column_by_name(&self.column) + .ok_or(Error::invalid_input( + format!("column {} not found", self.column).as_str(), + location!(), + ))? + .as_fixed_size_list(); + let vectors = lance_linalg::kernels::normalize_fsl(vectors)?; + batch = batch.replace_column_by_name(&self.column, Arc::new(vectors))?; + } + + let sizes = self.build_partition(partition, &batch).await?; + partition_sizes[partition] = sizes; + log::info!( + "partition {} built, progress {}/{}", + partition, + i + 1, + ivf.num_partitions() + ); + } + self.partition_sizes = partition_sizes; + Ok(self) } async fn build_partition(&self, part_id: usize, batch: &RecordBatch) -> Result<(usize, usize)> { - let object_store = ObjectStore::local(); + let quantizer = self.quantizer.clone().ok_or(Error::invalid_input( + "quantizer not set before building partition", + location!(), + ))?; // build quantized vector storage - let storage = StorageBuilder::new( - self.column.clone(), - self.distance_type, - self.quantizer.clone(), - ) - .build(batch)?; - let path = self.temp_dir.child(format!("storage_part{}", part_id)); - let writer = object_store.create(&path).await?; - let mut writer = FileWriter::try_new( - writer, - path.to_string(), - storage.schema().as_ref().try_into()?, - Default::default(), - )?; - for batch in storage.to_batches()? { - writer.write_batch(&batch).await?; - } - let storage_len = writer.finish().await? as usize; + let object_store = ObjectStore::local(); + let storage_len = { + let storage = StorageBuilder::new(self.column.clone(), self.distance_type, quantizer) + .build(batch)?; + let path = self.temp_dir.child(format!("storage_part{}", part_id)); + let writer = object_store.create(&path).await?; + let mut writer = FileWriter::try_new( + writer, + path.to_string(), + storage.schema().as_ref().try_into()?, + Default::default(), + )?; + for batch in storage.to_batches()? { + writer.write_batch(&batch).await?; + } + writer.finish().await? 
as usize + }; // build the sub index, with in-memory storage - let sub_index = S::index_vectors(&storage, self.sub_index_params.clone())?; - let path = self.temp_dir.child(format!("index_part{}", part_id)); - let writer = object_store.create(&path).await?; - let index_batch = sub_index.to_batch()?; - let mut writer = FileWriter::try_new( - writer, - path.to_string(), - index_batch.schema_ref().as_ref().try_into()?, - Default::default(), - )?; - writer.write_batch(&index_batch).await?; - let index_len = writer.finish().await? as usize; + let index_len = { + let distance_type = match self.distance_type { + DistanceType::Cosine | DistanceType::Dot => DistanceType::L2, + _ => self.distance_type, + }; + let vectors = batch[&self.column].as_fixed_size_list(); + let flat_storage = FlatStorage::new(vectors.clone(), distance_type); + let sub_index = S::index_vectors(&flat_storage, self.sub_index_params.clone())?; + let path = self.temp_dir.child(format!("index_part{}", part_id)); + let writer = object_store.create(&path).await?; + let index_batch = sub_index.to_batch()?; + let mut writer = FileWriter::try_new( + writer, + path.to_string(), + index_batch.schema_ref().as_ref().try_into()?, + Default::default(), + )?; + writer.write_batch(&index_batch).await?; + writer.finish().await? as usize + }; Ok((storage_len, index_len)) } - async fn merge_partitions( - &self, - centroids: FixedSizeListArray, - partition_sizes: Vec<(usize, usize)>, - ) -> Result<()> { + async fn merge_partitions(&mut self) -> Result<()> { + let ivf = self.ivf.as_ref().ok_or(Error::invalid_input( + "IVF not set before merge partitions", + location!(), + ))?; + let quantizer = self.quantizer.clone().ok_or(Error::invalid_input( + "quantizer not set before merge partitions", + location!(), + ))?; + let partition_sizes = std::mem::take(&mut self.partition_sizes); + if partition_sizes.is_empty() { + return Err(Error::invalid_input("no partition to merge", location!())); + } + // prepare the final writers let storage_path = self.index_dir.child(INDEX_AUXILIARY_FILE_NAME); let index_path = self.index_dir.child(INDEX_FILE_NAME); @@ -234,12 +473,16 @@ impl IvfIndexBuilder { )?; // maintain the IVF partitions - let mut storage_ivf = IvfData::empty(); - let mut index_ivf = IvfData::with_centroids(Arc::new(centroids)); + let mut storage_ivf = IvfModel::empty(); + let mut index_ivf = IvfModel::new(ivf.centroids.clone().unwrap()); + let mut partition_storage_metadata = Vec::with_capacity(partition_sizes.len()); + let mut partition_index_metadata = Vec::with_capacity(partition_sizes.len()); let scheduler = ScanScheduler::new(Arc::new(ObjectStore::local()), 64); for (part_id, (storage_size, index_size)) in partition_sizes.into_iter().enumerate() { + log::info!("merging partition {}/{}", part_id, ivf.num_partitions()); if storage_size == 0 { - storage_ivf.add_partition(0) + storage_ivf.add_partition(0); + partition_storage_metadata.push(quantizer.metadata(None)?.to_string()); } else { let storage_part_path = self.temp_dir.child(format!("storage_part{}", part_id)); let reader = FileReader::try_open( @@ -268,10 +511,19 @@ impl IvfIndexBuilder { } storage_writer.as_mut().unwrap().write_batch(&batch).await?; storage_ivf.add_partition(batch.num_rows() as u32); + partition_storage_metadata.push( + reader + .schema() + .metadata + .get(STORAGE_METADATA_KEY) + .cloned() + .unwrap_or_default(), + ); } if index_size == 0 { - index_ivf.add_partition(0) + index_ivf.add_partition(0); + partition_index_metadata.push(String::new()); } else { let index_part_path 
= self.temp_dir.child(format!("index_part{}", part_id)); let reader = FileReader::try_open( @@ -292,7 +544,16 @@ impl IvfIndexBuilder { let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; index_writer.write_batch(&batch).await?; index_ivf.add_partition(batch.num_rows() as u32); + partition_index_metadata.push( + reader + .schema() + .metadata + .get(S::metadata_key()) + .cloned() + .unwrap_or_default(), + ); } + log::info!("merged partition {}/{}", part_id, ivf.num_partitions()); } let mut storage_writer = storage_writer.unwrap(); @@ -303,22 +564,52 @@ impl IvfIndexBuilder { .await?; storage_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string()); storage_writer.add_schema_metadata( - Q::metadata_key(), - self.quantizer.metadata(None)?.to_string(), + STORAGE_METADATA_KEY, + serde_json::to_string(&partition_storage_metadata)?, ); let index_ivf_pb = pb::Ivf::try_from(&index_ivf)?; - index_writer.add_schema_metadata(DISTANCE_TYPE_KEY, self.distance_type.to_string()); + let index_metadata = IndexMetadata { + index_type: S::name().to_string(), + distance_type: self.distance_type.to_string(), + }; + index_writer.add_schema_metadata( + INDEX_METADATA_SCHEMA_KEY, + serde_json::to_string(&index_metadata)?, + ); let ivf_buffer_pos = index_writer .add_global_buffer(index_ivf_pb.encode_to_vec().into()) .await?; index_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string()); + index_writer.add_schema_metadata( + S::metadata_key(), + serde_json::to_string(&partition_index_metadata)?, + ); storage_writer.finish().await?; index_writer.finish().await?; Ok(()) } + + // take vectors from the dataset + // used for reading vectors from existing indices + async fn take_vectors(&self, row_ids: &[u64]) -> Result> { + let column = self.column.clone(); + let object_store = self.dataset.object_store().clone(); + let projection = self.dataset.schema().project(&[column.as_str()])?; + // arrow uses i32 for index, so we chunk the row ids to avoid large batch causing overflow + let mut batches = Vec::new(); + for chunk in row_ids.chunks(object_store.block_size()) { + let batch = self.dataset.take_rows(chunk, &projection).await?; + let batch = batch.try_with_column( + ROW_ID_FIELD.clone(), + Arc::new(UInt64Array::from(chunk.to_vec())), + )?; + batches.push(batch); + } + Ok(batches) + } } #[cfg(test)] @@ -331,6 +622,8 @@ mod tests { use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::hnsw::HNSW; + use lance_index::vector::sq::builder::SQBuildParams; + use lance_index::vector::sq::ScalarQuantizer; use lance_index::vector::{ flat::index::{FlatIndex, FlatQuantizer}, ivf::IvfBuildParams, @@ -381,56 +674,50 @@ mod tests { let ivf_params = IvfBuildParams::default(); let index_dir = tempdir().unwrap(); let index_dir = Path::from(index_dir.path().to_str().unwrap()); - let shuffler = IvfShuffler::new( - dataset.object_store().clone(), - index_dir.child("shuffled"), - ivf_params.num_partitions, - ); + let shuffler = IvfShuffler::new(index_dir.child("shuffled"), ivf_params.num_partitions); - let fq = FlatQuantizer::new(DIM, DistanceType::L2); - let builder = super::IvfIndexBuilder::::new( + super::IvfIndexBuilder::::new( dataset, "vector".to_owned(), index_dir, DistanceType::L2, Box::new(shuffler), - ivf_params, + Some(ivf_params), + Some(()), (), - fq, ) + .unwrap() + .build() + .await .unwrap(); - - builder.build().await.unwrap(); } #[tokio::test] - async fn test_build_ivf_hnsw() { + async fn test_build_ivf_hnsw_sq() { let test_dir = 
tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); let (dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; let ivf_params = IvfBuildParams::default(); let hnsw_params = HnswBuildParams::default(); + let sq_params = SQBuildParams::default(); let index_dir = tempdir().unwrap(); let index_dir = Path::from(index_dir.path().to_str().unwrap()); - let shuffler = IvfShuffler::new( - dataset.object_store().clone(), - index_dir.child("shuffled"), - ivf_params.num_partitions, - ); + let shuffler = IvfShuffler::new(index_dir.child("shuffled"), ivf_params.num_partitions); - let fq = FlatQuantizer::new(DIM, DistanceType::L2); - let builder = super::IvfIndexBuilder::::new( + super::IvfIndexBuilder::::new( dataset, "vector".to_owned(), index_dir, DistanceType::L2, Box::new(shuffler), - ivf_params, + Some(ivf_params), + Some(sq_params), hnsw_params, - fq, ) + .unwrap() + .build() + .await .unwrap(); - builder.build().await.unwrap(); } } diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index cd7498eadc..0214e83a99 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -14,11 +14,14 @@ mod test { use approx::assert_relative_eq; use arrow::array::AsArray; - use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch}; + use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, UInt32Array}; use arrow_schema::{DataType, Field, Schema}; use async_trait::async_trait; use deepsize::{Context, DeepSizeOf}; use lance_arrow::FixedSizeListArrayExt; + use lance_index::vector::ivf::storage::IvfModel; + use lance_index::vector::quantizer::{QuantizationType, Quantizer}; + use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{vector::Query, Index, IndexType}; use lance_io::{local::LocalObjectReader, traits::Reader}; use lance_linalg::distance::MetricType; @@ -29,7 +32,7 @@ mod test { use crate::{ index::{ prefilter::{DatasetPreFilter, PreFilter}, - vector::ivf::{IVFIndex, Ivf}, + vector::ivf::IVFIndex, }, session::Session, Result, @@ -63,6 +66,10 @@ mod test { self } + fn as_vector_index(self: Arc) -> Result> { + Ok(self) + } + /// Retrieve index statistics as a JSON Value fn statistics(&self) -> Result { Ok(serde_json::Value::Null) @@ -93,6 +100,19 @@ mod test { Ok(self.ret_val.clone()) } + fn find_partitions(&self, _: &Query) -> Result { + unimplemented!("only for IVF") + } + + async fn search_in_partition( + &self, + _: usize, + _: &Query, + _: Arc, + ) -> Result { + unimplemented!("only for IVF") + } + fn is_loadable(&self) -> bool { true } @@ -122,6 +142,18 @@ mod test { Ok(()) } + fn ivf_model(&self) -> IvfModel { + unimplemented!("only for IVF") + } + fn quantizer(&self) -> Quantizer { + unimplemented!("only for IVF") + } + + /// the index type of this vector index. + fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + unimplemented!("only for IVF") + } + /// The metric type of this vector index. 
fn metric_type(&self) -> MetricType { self.metric_type @@ -132,10 +164,10 @@ mod test { async fn test_ivf_residual_handling() { let centroids = Float32Array::from_iter(vec![1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0]); let centroids = FixedSizeListArray::try_new_from_values(centroids, 2).unwrap(); - let mut ivf = Ivf::new(centroids); + let mut ivf = IvfModel::new(centroids); // Add 4 partitions for _ in 0..4 { - ivf.add_partition(0, 0); + ivf.add_partition(0); } // hold on to this pointer, because the index only holds a weak reference let session = Arc::new(Session::default()); diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index ffb032f1d4..c41ede6d80 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -13,7 +13,7 @@ use arrow_arith::numeric::sub; use arrow_array::{ cast::{as_struct_array, AsArray}, types::{Float16Type, Float32Type, Float64Type}, - Array, FixedSizeListArray, Float32Array, RecordBatch, StructArray, UInt32Array, + Array, FixedSizeListArray, RecordBatch, StructArray, UInt32Array, }; use arrow_ord::sort::sort_to_indices; use arrow_schema::{DataType, Schema}; @@ -30,20 +30,22 @@ use lance_file::{ format::MAGIC, writer::{FileWriter, FileWriterOptions}, }; -use lance_index::vector::v3::subindex::IvfSubIndex; +use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::quantizer::QuantizationType; +use lance_index::vector::v3::shuffler::IvfShuffler; +use lance_index::vector::v3::subindex::{IvfSubIndex, SubIndexType}; use lance_index::{ optimize::OptimizeOptions, vector::{ hnsw::{builder::HnswBuildParams, HNSWIndex, HNSW}, ivf::{ - builder::load_precomputed_partitions, - shuffler::shuffle_dataset, - storage::{IvfData, IVF_PARTITION_KEY}, - IvfBuildParams, + builder::load_precomputed_partitions, shuffler::shuffle_dataset, + storage::IVF_PARTITION_KEY, IvfBuildParams, }, pq::{PQBuildParams, ProductQuantizer}, quantizer::{Quantization, QuantizationMetadata, Quantizer}, - sq::{builder::SQBuildParams, ScalarQuantizer}, + sq::ScalarQuantizer, Query, VectorIndex, DIST_COL, }, Index, IndexMetadata, IndexType, INDEX_AUXILIARY_FILE_NAME, INDEX_METADATA_SCHEMA_KEY, @@ -63,7 +65,7 @@ use lance_linalg::{ distance::{DistanceType, Dot, MetricType, L2}, MatrixView, }; -use log::{debug, info}; +use log::info; use object_store::path::Path; use rand::{rngs::SmallRng, SeedableRng}; use roaring::RoaringBitmap; @@ -75,11 +77,12 @@ use uuid::Uuid; use self::io::write_hnsw_quantization_index_partitions; +use super::builder::IvfIndexBuilder; use super::{ pq::{build_pq_model, PQIndex}, utils::maybe_sample_training_data, }; -use crate::{dataset::builder::DatasetBuilder, index::vector::sq::build_sq_model}; +use crate::dataset::builder::DatasetBuilder; use crate::{ dataset::Dataset, index::{ @@ -100,7 +103,7 @@ pub struct IVFIndex { uuid: String, /// Ivf model - ivf: Ivf, + ivf: IvfModel, reader: Arc, @@ -129,7 +132,7 @@ impl IVFIndex { pub(crate) fn try_new( session: Arc, uuid: &str, - ivf: Ivf, + ivf: IvfModel, reader: Arc, sub_index: Arc, metric_type: MetricType, @@ -171,22 +174,26 @@ impl IVFIndex { let part_index = if let Some(part_idx) = session.index_cache.get_vector(&cache_key) { part_idx } else { - if partition_id >= self.ivf.lengths.len() { + if partition_id >= self.ivf.num_partitions() { return Err(Error::Index { message: format!( "partition id {} is out of range of {} partitions", partition_id, - self.ivf.lengths.len() + 
self.ivf.num_partitions() ), location: location!(), }); } - let offset = self.ivf.offsets[partition_id]; - let length = self.ivf.lengths[partition_id] as usize; + let range = self.ivf.row_range(partition_id); let idx = self .sub_index - .load_partition(self.reader.clone(), offset, length, partition_id) + .load_partition( + self.reader.clone(), + range.start, + range.end - range.start, + partition_id, + ) .await?; let idx: Arc = idx.into(); if write_cache { @@ -202,7 +209,7 @@ impl IVFIndex { /// Internal API with no stability guarantees. pub fn preprocess_query(&self, partition_id: usize, query: &Query) -> Result { if self.sub_index.use_residual() { - let partition_centroids = self.ivf.centroids.value(partition_id); + let partition_centroids = self.ivf.centroids.as_ref().unwrap().value(partition_id); let residual_key = sub(&query.key, &partition_centroids)?; let mut part_query = query.clone(); part_query.key = residual_key; @@ -211,34 +218,6 @@ impl IVFIndex { Ok(query.clone()) } } - - pub(crate) async fn search_in_partition( - &self, - partition_id: usize, - query: &Query, - pre_filter: Arc, - ) -> Result { - let part_index = self.load_partition(partition_id, true).await?; - - let query = self.preprocess_query(partition_id, query)?; - let batch = part_index.search(&query, pre_filter).await?; - Ok(batch) - } - - /// find the IVF partitions ids given the query vector. - /// - /// Internal API with no stability guarantees. - /// - /// Assumes the query vector is normalized if the metric type is cosine. - pub fn find_partitions(&self, query: &Query) -> Result { - let mt = if self.metric_type == MetricType::Cosine { - MetricType::L2 - } else { - self.metric_type - }; - - self.ivf.find_partitions(&query.key, query.nprobes, mt) - } } impl std::fmt::Debug for IVFIndex { @@ -265,6 +244,19 @@ pub(crate) async fn optimize_vector_indices( }); } + // try cast to v1 IVFIndex, + // fallback to v2 IVFIndex if it's not v1 IVFIndex + if !existing_indices[0].as_any().is::() { + return optimize_vector_indices_v2( + &dataset, + unindexed, + vector_column, + existing_indices, + options, + ) + .await; + } + let new_uuid = Uuid::new_v4(); let object_store = dataset.object_store(); let index_file = dataset @@ -277,7 +269,7 @@ pub(crate) async fn optimize_vector_indices( .as_any() .downcast_ref::() .ok_or(Error::Index { - message: "optimizing vector index: first index is not IVF".to_string(), + message: "optimizing vector index: the first index isn't IVF".to_string(), location: location!(), })?; @@ -325,6 +317,98 @@ pub(crate) async fn optimize_vector_indices( Ok((new_uuid, merged)) } +pub(crate) async fn optimize_vector_indices_v2( + dataset: &Dataset, + unindexed: Option, + vector_column: &str, + existing_indices: &[Arc], + options: &OptimizeOptions, +) -> Result<(Uuid, usize)> { + // Sanity check the indices + if existing_indices.is_empty() { + return Err(Error::Index { + message: "optimizing vector index: no existing index found".to_string(), + location: location!(), + }); + } + let existing_indices = existing_indices + .iter() + .cloned() + .map(|idx| idx.as_vector_index()) + .collect::>>()?; + + let new_uuid = Uuid::new_v4(); + let index_dir = dataset.indices_dir().child(new_uuid.to_string()); + let ivf_model = existing_indices[0].ivf_model(); + let quantizer = existing_indices[0].quantizer(); + let distance_type = existing_indices[0].metric_type(); + let num_partitions = ivf_model.num_partitions(); + let index_type = existing_indices[0].sub_index_type(); + + let temp_dir = tempfile::tempdir()?; + let 
temp_dir = temp_dir.path().to_str().unwrap().into(); + let shuffler = Box::new(IvfShuffler::new(temp_dir, num_partitions)); + let start_pos = if options.num_indices_to_merge > existing_indices.len() { + 0 + } else { + existing_indices.len() - options.num_indices_to_merge + }; + let indices_to_merge = existing_indices[start_pos..].to_vec(); + let merged_num = indices_to_merge.len(); + match index_type { + (SubIndexType::Flat, QuantizationType::Flat) => { + IvfIndexBuilder::::new_incremental( + dataset.clone(), + vector_column.to_owned(), + index_dir, + distance_type, + shuffler, + (), + )? + .with_ivf(ivf_model) + .with_quantizer(quantizer.try_into()?) + .with_existing_indices(indices_to_merge) + .shuffle_data(unindexed) + .await? + .build() + .await?; + } + + (SubIndexType::Hnsw, QuantizationType::Scalar) => { + IvfIndexBuilder::::new( + dataset.clone(), + vector_column.to_owned(), + index_dir, + distance_type, + shuffler, + None, + None, + // TODO: get the HNSW parameters from the existing indices + HnswBuildParams::default(), + )? + .with_ivf(ivf_model) + .with_quantizer(quantizer.try_into()?) + .with_existing_indices(indices_to_merge) + .shuffle_data(unindexed) + .await? + .build() + .await?; + } + + (sub_index_type, quantizer_type) => { + return Err(Error::Index { + message: format!( + "optimizing vector index: unsupported index type IVF_{}_{}", + sub_index_type, quantizer_type + ), + location: location!(), + }); + } + } + + Ok((new_uuid, merged_num)) +} + #[allow(clippy::too_many_arguments)] async fn optimize_ivf_pq_indices( first_idx: &IVFIndex, @@ -340,8 +424,8 @@ async fn optimize_ivf_pq_indices( let dim = first_idx.ivf.dimension(); // TODO: merge `lance::vector::ivf::IVF` and `lance-index::vector::ivf::Ivf`` implementations. - let ivf = lance_index::vector::ivf::Ivf::with_pq( - first_idx.ivf.centroids.clone(), + let ivf = lance_index::vector::ivf::IvfTransformer::with_pq( + first_idx.ivf.centroids.clone().unwrap(), metric_type, vector_column, pq_index.pq.clone(), @@ -349,10 +433,10 @@ async fn optimize_ivf_pq_indices( ); // Shuffled un-indexed data with partition. - let shuffled = if let Some(stream) = unindexed { - Some( + let shuffled = match unindexed { + Some(unindexed) => Some( shuffle_dataset( - stream, + unindexed, vector_column, ivf.into(), None, @@ -362,12 +446,11 @@ async fn optimize_ivf_pq_indices( None, ) .await?, - ) - } else { - None + ), + None => None, }; - let mut ivf_mut = Ivf::new(first_idx.ivf.centroids.clone()); + let mut ivf_mut = IvfModel::new(first_idx.ivf.centroids.clone().unwrap()); let start_pos = if options.num_indices_to_merge > existing_indices.len() { 0 @@ -420,8 +503,8 @@ async fn optimize_ivf_hnsw_indices( ) -> Result { let distance_type = first_idx.metric_type; let quantizer = hnsw_index.quantizer().clone(); - let ivf = lance_index::vector::ivf::new_ivf_with_quantizer( - first_idx.ivf.centroids.clone(), + let ivf = lance_index::vector::ivf::new_ivf_transformer_with_quantizer( + first_idx.ivf.centroids.clone().unwrap(), distance_type, vector_column, quantizer.clone(), @@ -429,10 +512,10 @@ async fn optimize_ivf_hnsw_indices( )?; // Shuffled un-indexed data with partition. 
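     // Each un-indexed row is assigned to its nearest partition by the shuffle,
     // so the merge step can append rows to the new index partition by partition.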
- let shuffled = if let Some(stream) = unindexed { - Some( + let unindexed_data = match unindexed { + Some(unindexed) => Some( shuffle_dataset( - stream, + unindexed, vector_column, Arc::new(ivf), None, @@ -442,12 +525,11 @@ async fn optimize_ivf_hnsw_indices( None, ) .await?, - ) - } else { - None + ), + None => None, }; - let mut ivf_mut = Ivf::new(first_idx.ivf.centroids.clone()); + let mut ivf_mut = IvfModel::new(first_idx.ivf.centroids.clone().unwrap()); let start_pos = if options.num_indices_to_merge > existing_indices.len() { 0 @@ -549,7 +631,7 @@ async fn optimize_ivf_hnsw_indices( Some(&mut aux_writer), &mut ivf_mut, quantizer, - shuffled, + unindexed_data, Some(&indices_to_merge), ) .await?; @@ -558,12 +640,7 @@ async fn optimize_ivf_hnsw_indices( let hnsw_metadata_json = json!(hnsw_metadata); writer.add_metadata(IVF_PARTITION_KEY, &hnsw_metadata_json.to_string()); - // Convert ['Ivf'] to [`IvfData`] for new index format - let mut ivf_data = IvfData::with_centroids(Arc::new(ivf_mut.centroids.clone())); - for length in ivf_mut.lengths { - ivf_data.add_partition(length); - } - ivf_data.write(&mut writer).await?; + ivf_mut.write(&mut writer).await?; writer.finish().await?; // Write the aux file @@ -637,19 +714,22 @@ impl Index for IVFIndex { self } + fn as_vector_index(self: Arc) -> Result> { + Ok(self) + } + fn index_type(&self) -> IndexType { IndexType::Vector } fn statistics(&self) -> Result { - let partitions_statistics = self - .ivf - .lengths - .iter() - .map(|&len| IvfIndexPartitionStatistics { size: len }) + let partitions_statistics = (0..self.ivf.num_partitions()) + .map(|part_id| IvfIndexPartitionStatistics { + size: self.ivf.partition_size(part_id) as u32, + }) .collect::>(); - let centroid_vecs = centroids_to_vectors(&self.ivf.centroids)?; + let centroid_vecs = centroids_to_vectors(self.ivf.centroids.as_ref().unwrap())?; Ok(serde_json::to_value(IvfIndexStatistics { index_type: "IVF".to_string(), @@ -712,6 +792,34 @@ impl VectorIndex for IVFIndex { Ok(as_struct_array(&taken_distances).into()) } + /// find the IVF partitions ids given the query vector. + /// + /// Internal API with no stability guarantees. + /// + /// Assumes the query vector is normalized if the metric type is cosine. + fn find_partitions(&self, query: &Query) -> Result { + let mt = if self.metric_type == MetricType::Cosine { + MetricType::L2 + } else { + self.metric_type + }; + + self.ivf.find_partitions(&query.key, query.nprobes, mt) + } + + async fn search_in_partition( + &self, + partition_id: usize, + query: &Query, + pre_filter: Arc, + ) -> Result { + let part_index = self.load_partition(partition_id, true).await?; + + let query = self.preprocess_query(partition_id, query)?; + let batch = part_index.search(&query, pre_filter).await?; + Ok(batch) + } + fn is_loadable(&self) -> bool { false } @@ -753,6 +861,19 @@ impl VectorIndex for IVFIndex { }) } + fn ivf_model(&self) -> IvfModel { + self.ivf.clone() + } + + fn quantizer(&self) -> Quantizer { + unimplemented!("only for v2 IVFIndex") + } + + /// the index type of this vector index. 
+ fn sub_index_type(&self) -> (SubIndexType, QuantizationType) { + unimplemented!("only for v2 IVFIndex") + } + fn metric_type(&self) -> MetricType { self.metric_type } @@ -779,7 +900,7 @@ pub struct IvfPQIndexMetadata { pub(crate) metric_type: MetricType, /// IVF model - pub(crate) ivf: Ivf, + pub(crate) ivf: IvfModel, /// Product Quantizer pub(crate) pq: Arc, @@ -835,112 +956,6 @@ impl TryFrom<&IvfPQIndexMetadata> for pb::Index { }) } } -/// Ivf Model -#[derive(Debug, Clone)] -pub(crate) struct Ivf { - /// Centroids of each partition. - /// - /// It is a 2-D `(num_partitions * dimension)` of vector array. - pub(crate) centroids: FixedSizeListArray, - - /// Offset of each partition in the file. - offsets: Vec, - - /// Number of vectors in each partition. - lengths: Vec, -} - -impl Ivf { - pub(super) fn new(centroids: FixedSizeListArray) -> Self { - Self { - centroids, - offsets: vec![], - lengths: vec![], - } - } - - /// Ivf model dimension. - pub(super) fn dimension(&self) -> usize { - self.centroids.value_length() as usize - } - - /// Number of IVF partitions. - fn num_partitions(&self) -> usize { - self.centroids.len() - } - - /// Use the query vector to find `nprobes` closest partitions. - fn find_partitions( - &self, - query: &dyn Array, - nprobes: usize, - metric_type: MetricType, - ) -> Result { - let internal = - lance_index::vector::ivf::new_ivf(self.centroids.clone(), metric_type, vec![]); - internal.find_partitions(query, nprobes) - } - - /// Add the offset and length of one partition. - pub(super) fn add_partition(&mut self, offset: usize, len: u32) { - self.offsets.push(offset); - self.lengths.push(len); - } -} - -/// Convert IvfModel to protobuf. -impl TryFrom<&Ivf> for pb::Ivf { - type Error = Error; - - fn try_from(ivf: &Ivf) -> Result { - if ivf.offsets.len() != ivf.centroids.len() { - return Err(Error::io( - "Ivf model has not been populated".to_string(), - location!(), - )); - } - Ok(Self { - centroids: vec![], - offsets: ivf.offsets.iter().map(|o| *o as u64).collect(), - lengths: ivf.lengths.clone(), - centroids_tensor: Some((&ivf.centroids).try_into()?), - }) - } -} - -/// Convert IvfModel to protobuf. -impl TryFrom<&pb::Ivf> for Ivf { - type Error = Error; - - fn try_from(proto: &pb::Ivf) -> Result { - let centroids = if let Some(tensor) = proto.centroids_tensor.as_ref() { - debug!("Ivf: loading IVF centroids from index format v2"); - FixedSizeListArray::try_from(tensor)? - } else { - debug!("Ivf: loading IVF centroids from index format v1"); - // For backward-compatibility - let f32_centroids = Float32Array::from(proto.centroids.clone()); - let dimension = f32_centroids.len() / proto.lengths.len(); - FixedSizeListArray::try_new_from_values(f32_centroids, dimension as i32)? 
- }; - - let mut ivf = Self { - centroids, - offsets: proto.offsets.iter().map(|o| *o as usize).collect(), - lengths: proto.lengths.clone(), - }; - - if ivf.offsets.is_empty() && !ivf.lengths.is_empty() { - let mut offset = 0; - for len in &ivf.lengths { - ivf.offsets.push(offset); - offset += *len as usize; - } - } - - Ok(ivf) - } -} fn sanity_check<'a>(dataset: &'a Dataset, column: &str) -> Result<&'a Field> { let Some(field) = dataset.schema().field(column) else { @@ -1035,7 +1050,7 @@ pub(super) async fn build_ivf_model( dim: usize, metric_type: MetricType, params: &IvfBuildParams, -) -> Result { +) -> Result { if let Some(centroids) = params.centroids.as_ref() { info!("Pre-computed IVF centroids is provided, skip IVF training"); if centroids.values().len() != params.num_partitions * dim { @@ -1048,7 +1063,7 @@ pub(super) async fn build_ivf_model( location: location!(), }); } - return Ok(Ivf::new(centroids.as_ref().clone())); + return Ok(IvfModel::new(centroids.as_ref().clone())); } let sample_size_hint = params.num_partitions * params.sample_rate; @@ -1088,7 +1103,7 @@ async fn build_ivf_model_and_pq( metric_type: MetricType, ivf_params: &IvfBuildParams, pq_params: &PQBuildParams, -) -> Result<(Ivf, Arc)> { +) -> Result<(IvfModel, Arc)> { sanity_check_params(ivf_params, pq_params)?; info!( @@ -1122,40 +1137,6 @@ async fn build_ivf_model_and_pq( Ok((ivf_model, pq)) } -async fn build_ivf_model_and_sq( - dataset: &Dataset, - column: &str, - metric_type: MetricType, - ivf_params: &IvfBuildParams, - sq_params: &SQBuildParams, -) -> Result<(Ivf, ScalarQuantizer)> { - sanity_check_ivf_params(ivf_params)?; - - info!( - "Building vector index: IVF{},SQ{}, metric={}", - ivf_params.num_partitions, sq_params.num_bits, metric_type, - ); - - let field = sanity_check(dataset, column)?; - let dim = if let DataType::FixedSizeList(_, d) = field.data_type() { - d as usize - } else { - return Err(Error::Index { - message: format!( - "VectorIndex requires the column data type to be fixed size list of floats, got {}", - field.data_type() - ), - location: location!(), - }); - }; - - let ivf_model = build_ivf_model(dataset, column, dim, metric_type, ivf_params).await?; - - let sq = build_sq_model(dataset, column, metric_type, sq_params).await?; - - Ok((ivf_model, sq)) -} - async fn scan_index_field_stream( dataset: &Dataset, column: &str, @@ -1250,41 +1231,6 @@ pub async fn build_ivf_hnsw_pq_index( .await } -#[allow(clippy::too_many_arguments)] -pub async fn build_ivf_hnsw_sq_index( - dataset: &Dataset, - column: &str, - index_name: &str, - uuid: &str, - metric_type: MetricType, - ivf_params: &IvfBuildParams, - hnsw_params: &HnswBuildParams, - sq_params: &SQBuildParams, -) -> Result<()> { - let (ivf_model, sq) = - build_ivf_model_and_sq(dataset, column, metric_type, ivf_params, sq_params).await?; - let stream = scan_index_field_stream(dataset, column).await?; - let precomputed_partitions = load_precomputed_partitions_if_available(ivf_params).await?; - - write_ivf_hnsw_file( - dataset, - column, - index_name, - uuid, - &[], - ivf_model, - Quantizer::Scalar(sq), - metric_type, - hnsw_params, - stream, - precomputed_partitions, - ivf_params.shuffle_partition_batches, - ivf_params.shuffle_partition_concurrency, - ivf_params.precomputed_shuffle_buffers.clone(), - ) - .await -} - struct RemapPageTask { offset: usize, length: u32, @@ -1317,7 +1263,7 @@ impl RemapPageTask { Ok(self) } - async fn write(self, writer: &mut ObjectWriter, ivf: &mut Ivf) -> Result<()> { + async fn write(self, writer: &mut ObjectWriter, 
ivf: &mut IvfModel) -> Result<()> { let page = self.page.as_ref().expect("Load was not called"); let page: &PQIndex = page .as_any() @@ -1367,7 +1313,7 @@ pub(crate) async fn remap_index_file( .map(|task| task.load_and_remap(reader.clone(), index, mapping)) .buffered(num_cpus::get()); - let mut ivf = Ivf { + let mut ivf = IvfModel { centroids: index.ivf.centroids.clone(), offsets: Vec::with_capacity(index.ivf.offsets.len()), lengths: Vec::with_capacity(index.ivf.lengths.len()), @@ -1415,7 +1361,7 @@ async fn write_ivf_pq_file( index_name: &str, uuid: &str, transformers: &[Box], - mut ivf: Ivf, + mut ivf: IvfModel, pq: Arc, metric_type: MetricType, stream: impl RecordBatchStream + Unpin + 'static, @@ -1481,7 +1427,7 @@ async fn write_ivf_hnsw_file( _index_name: &str, uuid: &str, _transformers: &[Box], - mut ivf: Ivf, + mut ivf: IvfModel, quantizer: Quantizer, distance_type: DistanceType, hnsw_params: &HnswBuildParams, @@ -1598,12 +1544,7 @@ async fn write_ivf_hnsw_file( let hnsw_metadata_json = json!(hnsw_metadata); writer.add_metadata(IVF_PARTITION_KEY, &hnsw_metadata_json.to_string()); - // Convert ['Ivf'] to [`IvfData`] for new index format - let mut ivf_data = IvfData::with_centroids(Arc::new(ivf.centroids.clone())); - for length in ivf.lengths { - ivf_data.add_partition(length); - } - ivf_data.write(&mut writer).await?; + ivf.write(&mut writer).await?; writer.finish().await?; // Write the aux file @@ -1617,7 +1558,7 @@ async fn do_train_ivf_model( dimension: usize, metric_type: MetricType, params: &IvfBuildParams, -) -> Result +) -> Result where T::Native: Dot + L2 + Normalize, { @@ -1634,7 +1575,7 @@ where params.sample_rate, ) .await?; - Ok(Ivf::new(FixedSizeListArray::try_new_from_values( + Ok(IvfModel::new(FixedSizeListArray::try_new_from_values( centroids, dimension as i32, )?)) @@ -1645,7 +1586,7 @@ async fn train_ivf_model( data: &FixedSizeListArray, distance_type: DistanceType, params: &IvfBuildParams, -) -> Result { +) -> Result { assert!( distance_type != DistanceType::Cosine, "Cosine metric should be done by normalized L2 distance", @@ -1684,11 +1625,12 @@ mod tests { use std::ops::Range; use arrow_array::types::UInt64Type; - use arrow_array::{RecordBatchIterator, RecordBatchReader, UInt64Array}; + use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader, UInt64Array}; use arrow_schema::Field; use itertools::Itertools; use lance_core::utils::address::RowAddress; use lance_core::ROW_ID; + use lance_index::vector::sq::builder::SQBuildParams; use lance_linalg::distance::l2_distance_batch; use lance_testing::datagen::{ generate_random_array, generate_random_array_with_range, generate_random_array_with_seed, @@ -2214,13 +2156,14 @@ mod tests { let ivf_model = build_ivf_model(&dataset, "vector", DIM, MetricType::L2, &ivf_params) .await .unwrap(); - assert_eq!(2, ivf_model.centroids.len()); - assert_eq!(32, ivf_model.centroids.value_length()); + assert_eq!(2, ivf_model.centroids.as_ref().unwrap().len()); + assert_eq!(32, ivf_model.centroids.as_ref().unwrap().value_length()); assert_eq!(2, ivf_model.num_partitions()); // All centroids values should be in the range [1000, 1100] ivf_model .centroids + .unwrap() .values() .as_primitive::() .values() @@ -2241,13 +2184,14 @@ mod tests { let ivf_model = build_ivf_model(&dataset, "vector", DIM, MetricType::Cosine, &ivf_params) .await .unwrap(); - assert_eq!(2, ivf_model.centroids.len()); - assert_eq!(32, ivf_model.centroids.value_length()); + assert_eq!(2, ivf_model.centroids.as_ref().unwrap().len()); + assert_eq!(32, 
ivf_model.centroids.as_ref().unwrap().value_length()); assert_eq!(2, ivf_model.num_partitions()); // All centroids values should be in the range [1000, 1100] ivf_model .centroids + .unwrap() .values() .as_primitive::() .values() @@ -2574,91 +2518,6 @@ mod tests { ); } - async fn test_create_ivf_hnsw_sq(distance_type: DistanceType, expected_recall: f32) { - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - - let nlist = 4; - let (mut dataset, vector_array) = generate_test_dataset(test_uri, 0.0..1.0).await; - - let ivf_params = IvfBuildParams::new(nlist); - let sq_params = SQBuildParams::default(); - let hnsw_params = HnswBuildParams::default(); - let params = VectorIndexParams::with_ivf_hnsw_sq_params( - distance_type, - ivf_params, - hnsw_params, - sq_params, - ); - - dataset - .create_index(&["vector"], IndexType::Vector, None, ¶ms, false) - .await - .unwrap(); - - let mat = MatrixView::::try_from(vector_array.as_ref()).unwrap(); - let query = vector_array.value(0); - let query = query.as_primitive::(); - let k = 100; - let results = dataset - .scan() - .with_row_id() - .nearest("vector", query, k) - .unwrap() - .nprobs(nlist) - .try_into_stream() - .await - .unwrap() - .try_collect::>() - .await - .unwrap(); - assert_eq!(1, results.len()); - assert_eq!(k, results[0].num_rows()); - - let row_ids = results[0] - .column_by_name(ROW_ID) - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .map(|v| v.unwrap() as u32) - .collect::>(); - let dists = results[0] - .column_by_name("_distance") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap() - .values() - .to_vec(); - - let results = dists.into_iter().zip(row_ids.into_iter()).collect_vec(); - let gt = ground_truth(&mat, query.values(), k, distance_type); - - let results_set = results.iter().map(|r| r.1).collect::>(); - let gt_set = gt.iter().map(|r| r.1).collect::>(); - - let recall = results_set.intersection(>_set).count() as f32 / k as f32; - assert!( - recall >= expected_recall, - "recall: {}\n results: {:?}\n\ngt: {:?}", - recall, - results, - gt, - ); - } - - #[tokio::test] - async fn test_create_ivf_hnsw_sq_cosine() { - test_create_ivf_hnsw_sq(DistanceType::Cosine, 0.9).await - } - - #[tokio::test] - async fn test_create_ivf_hnsw_sq_dot() { - test_create_ivf_hnsw_sq(DistanceType::Dot, 0.8).await - } - #[tokio::test] async fn test_create_ivf_hnsw_with_empty_partition() { let test_dir = tempdir().unwrap(); @@ -2778,6 +2637,8 @@ mod tests { assert!(ivf_idx .ivf .centroids + .as_ref() + .unwrap() .values() .as_primitive::() .values() diff --git a/rust/lance/src/index/vector/ivf/builder.rs b/rust/lance/src/index/vector/ivf/builder.rs index d50f344481..5b06d60619 100644 --- a/rust/lance/src/index/vector/ivf/builder.rs +++ b/rust/lance/src/index/vector/ivf/builder.rs @@ -6,6 +6,7 @@ use std::ops::Range; use std::sync::Arc; use lance_file::writer::FileWriter; +use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::quantizer::Quantizer; use lance_table::io::manifest::ManifestDescribing; use object_store::path::Path; @@ -15,13 +16,13 @@ use tracing::instrument; use lance_core::{traits::DatasetTakeRows, Error, Result, ROW_ID}; use lance_index::vector::{ hnsw::{builder::HnswBuildParams, HnswMetadata}, - ivf::{shuffler::shuffle_dataset, storage::IvfData}, + ivf::shuffler::shuffle_dataset, pq::ProductQuantizer, }; use lance_io::{stream::RecordBatchStream, traits::Writer}; use lance_linalg::distance::MetricType; -use crate::index::vector::ivf::{io::write_pq_partitions, Ivf}; 
+use crate::index::vector::ivf::io::write_pq_partitions; use super::io::write_hnsw_quantization_index_partitions; @@ -34,7 +35,7 @@ pub(super) async fn build_partitions( writer: &mut dyn Writer, data: impl RecordBatchStream + Unpin + 'static, column: &str, - ivf: &mut Ivf, + ivf: &mut IvfModel, pq: Arc, metric_type: MetricType, part_range: Range, @@ -57,8 +58,8 @@ pub(super) async fn build_partitions( }); } - let ivf_model = lance_index::vector::ivf::Ivf::with_pq( - ivf.centroids.clone(), + let ivf_transformer = lance_index::vector::ivf::IvfTransformer::with_pq( + ivf.centroids.clone().unwrap(), metric_type, column, pq.clone(), @@ -68,7 +69,7 @@ pub(super) async fn build_partitions( let stream = shuffle_dataset( data, column, - ivf_model.into(), + ivf_transformer.into(), precomputed_partitons, ivf.num_partitions() as u32, shuffle_partition_batches, @@ -93,7 +94,7 @@ pub(super) async fn build_hnsw_partitions( auxiliary_writer: Option<&mut FileWriter>, data: impl RecordBatchStream + Unpin + 'static, column: &str, - ivf: &mut Ivf, + ivf: &mut IvfModel, quantizer: Quantizer, metric_type: MetricType, hnsw_params: &HnswBuildParams, @@ -102,7 +103,7 @@ pub(super) async fn build_hnsw_partitions( shuffle_partition_batches: usize, shuffle_partition_concurrency: usize, precomputed_shuffle_buffers: Option<(Path, Vec)>, -) -> Result<(Vec, IvfData)> { +) -> Result<(Vec, IvfModel)> { let schema = data.schema(); if schema.column_with_name(column).is_none() { return Err(Error::Schema { @@ -117,8 +118,8 @@ pub(super) async fn build_hnsw_partitions( }); } - let ivf_model = lance_index::vector::ivf::new_ivf_with_quantizer( - ivf.centroids.clone(), + let ivf_model = lance_index::vector::ivf::new_ivf_transformer_with_quantizer( + ivf.centroids.clone().unwrap(), metric_type, column, quantizer.clone(), diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs index 74a0a17fb1..0e7e4ebd7e 100644 --- a/rust/lance/src/index/vector/ivf/io.rs +++ b/rust/lance/src/index/vector/ivf/io.rs @@ -21,16 +21,12 @@ use lance_core::Error; use lance_file::reader::FileReader; use lance_file::writer::FileWriter; use lance_index::scalar::IndexWriter; -use lance_index::vector::hnsw::builder::HNSW_METADATA_KEY; +use lance_index::vector::hnsw::HNSW; use lance_index::vector::hnsw::{builder::HnswBuildParams, HnswMetadata}; -use lance_index::vector::ivf::storage::IvfData; +use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::ProductQuantizer; +use lance_index::vector::quantizer::{Quantization, Quantizer}; use lance_index::vector::v3::subindex::IvfSubIndex; -use lance_index::vector::{ - quantizer::{Quantization, Quantizer}, - sq::ScalarQuantizer, - storage::VectorStore, -}; use lance_index::vector::{PART_ID_COLUMN, PQ_CODE_COLUMN}; use lance_io::encodings::plain::PlainEncoder; use lance_io::object_store::ObjectStore; @@ -45,10 +41,10 @@ use snafu::{location, Location}; use tempfile::TempDir; use tokio::sync::Semaphore; -use super::{IVFIndex, Ivf}; +use super::IVFIndex; use crate::dataset::ROW_ID; use crate::index::vector::pq::{build_pq_storage, PQIndex}; -use crate::index::vector::sq::build_sq_storage; + use crate::Result; // TODO: make it configurable, limit by the number of CPU cores & memory @@ -147,7 +143,7 @@ async fn merge_streams( /// TODO: migrate this function to `lance-index` crate. 
pub(super) async fn write_pq_partitions( writer: &mut dyn Writer, - ivf: &mut Ivf, + ivf: &mut IvfModel, streams: Option>>>, existing_indices: Option<&[&IVFIndex]>, ) -> Result<()> { @@ -228,7 +224,7 @@ pub(super) async fn write_pq_partitions( .await?; let total_records = row_id_array.iter().map(|a| a.len()).sum::(); - ivf.add_partition(writer.tell().await?, total_records as u32); + ivf.add_partition_with_offset(writer.tell().await?, total_records as u32); if total_records > 0 { let pq_refs = pq_array.iter().map(|a| a.as_ref()).collect::>(); PlainEncoder::write(writer, &pq_refs).await?; @@ -253,11 +249,11 @@ pub(super) async fn write_hnsw_quantization_index_partitions( hnsw_params: &HnswBuildParams, writer: &mut FileWriter, mut auxiliary_writer: Option<&mut FileWriter>, - ivf: &mut Ivf, + ivf: &mut IvfModel, quantizer: Quantizer, streams: Option>>>, existing_indices: Option<&[&IVFIndex]>, -) -> Result<(Vec, IvfData)> { +) -> Result<(Vec, IvfModel)> { let hnsw_params = Arc::new(hnsw_params.clone()); let mut streams_heap = BinaryHeap::new(); @@ -384,14 +380,14 @@ pub(super) async fn write_hnsw_quantization_index_partitions( })); } - let mut aux_ivf = IvfData::empty(); + let mut aux_ivf = IvfModel::empty(); let mut hnsw_metadata = Vec::with_capacity(ivf.num_partitions()); for (part_id, task) in tasks.into_iter().enumerate() { let offset = writer.len(); let num_rows = task.await??; if num_rows == 0 { - ivf.add_partition(offset, 0); + ivf.add_partition(0); aux_ivf.add_partition(0); hnsw_metadata.push(HnswMetadata::default()); continue; @@ -414,9 +410,9 @@ pub(super) async fn write_hnsw_quantization_index_partitions( .await?; writer.write(&batches).await?; - ivf.add_partition(offset, (writer.len() - offset) as u32); + ivf.add_partition((writer.len() - offset) as u32); hnsw_metadata.push(serde_json::from_str( - part_reader.schema().metadata[HNSW_METADATA_KEY].as_str(), + part_reader.schema().metadata[HNSW::metadata_key()].as_str(), )?); std::mem::drop(part_reader); object_store.delete(part_file).await?; @@ -498,13 +494,7 @@ async fn build_hnsw_quantization_partition( aux_writer.unwrap(), )), - Quantizer::Scalar(sq) => tokio::spawn(build_and_write_sq_storage( - metric_type, - row_ids, - vectors, - sq, - aux_writer.unwrap(), - )), + _ => unreachable!("IVF_HNSW_SQ has been moved to v2 index builder"), }; let index_rows = futures::join!(build_hnsw, build_store).0?; @@ -547,26 +537,6 @@ async fn build_and_write_pq_storage( Ok(()) } -async fn build_and_write_sq_storage( - distance_type: DistanceType, - row_ids: Arc, - vectors: Arc, - sq: ScalarQuantizer, - mut writer: FileWriter, -) -> Result<()> { - let storage = spawn_cpu(move || { - let storage = build_sq_storage(distance_type, row_ids, vectors, sq)?; - Ok(storage) - }) - .await?; - - for batch in storage.to_batches()? { - writer.write_record_batch(batch.clone()).await?; - } - writer.finish().await?; - Ok(()) -} - #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 3b5ef10bb3..d88ebf6d60 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -4,6 +4,7 @@ //! IVF - Inverted File index. 
use core::fmt; +use std::marker::PhantomData; use std::{ any::Any, collections::HashMap, @@ -19,9 +20,16 @@ use arrow_array::{RecordBatch, StructArray, UInt32Array}; use async_trait::async_trait; use deepsize::DeepSizeOf; use futures::prelude::stream::{self, StreamExt, TryStreamExt}; +use lance_arrow::RecordBatchExt; use lance_core::{cache::DEFAULT_INDEX_CACHE_SIZE, Error, Result}; use lance_encoding::decoder::{DecoderMiddlewareChain, FilterExpression}; use lance_file::v2::reader::FileReader; +use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::hnsw::HNSW; +use lance_index::vector::ivf::storage::IvfModel; +use lance_index::vector::quantizer::{QuantizationType, Quantizer}; +use lance_index::vector::sq::ScalarQuantizer; +use lance_index::vector::v3::subindex::SubIndexType; use lance_index::{ pb, vector::{ @@ -30,6 +38,7 @@ use lance_index::{ }, Index, IndexType, INDEX_AUXILIARY_FILE_NAME, INDEX_FILE_NAME, }; +use lance_index::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; use lance_io::{ object_store::ObjectStore, scheduler::ScanScheduler, traits::Reader, ReadBatchParams, }; @@ -46,20 +55,28 @@ use crate::{ session::Session, }; -use super::{centroids_to_vectors, Ivf, IvfIndexPartitionStatistics, IvfIndexStatistics}; +use super::{centroids_to_vectors, IvfIndexPartitionStatistics, IvfIndexStatistics}; + +#[derive(Debug)] +struct PartitionEntry { + index: S, + storage: Q::Storage, +} + /// IVF Index. #[derive(Debug)] -pub struct IVFIndex { +pub struct IVFIndex { uuid: String, /// Ivf model - ivf: Ivf, + ivf: IvfModel, reader: FileReader, - storage: IvfQuantizationStorage, + sub_index_metadata: Vec, + storage: IvfQuantizationStorage, /// Index in each partition. - sub_index_cache: Cache>, + partition_cache: Cache>>, distance_type: DistanceType, @@ -68,16 +85,18 @@ pub struct IVFIndex { /// The session cache, used when fetching pages #[allow(dead_code)] session: Weak, + + _marker: PhantomData, } -impl DeepSizeOf for IVFIndex { +impl DeepSizeOf for IVFIndex { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.uuid.deep_size_of_children(context) + self.storage.deep_size_of_children(context) // Skipping session since it is a weak ref } } -impl IVFIndex { +impl IVFIndex { /// Create a new IVF index. pub(crate) async fn try_new( object_store: Arc, @@ -95,17 +114,18 @@ impl IVFIndex { DecoderMiddlewareChain::default(), ) .await?; - let distance_type = DistanceType::try_from( + let index_metadata: IndexMetadata = serde_json::from_str( index_reader .schema() .metadata - .get(DISTANCE_TYPE_KEY) + .get(INDEX_METADATA_SCHEMA_KEY) .ok_or(Error::Index { message: format!("{} not found", DISTANCE_TYPE_KEY), location: location!(), })? 
.as_str(), )?; + let distance_type = DistanceType::try_from(index_metadata.distance_type.as_str())?; let ivf_pos = index_reader .schema() @@ -121,7 +141,17 @@ impl IVFIndex { location: location!(), })?; let ivf_pb_bytes = index_reader.read_global_buffer(ivf_pos).await?; - let ivf = Ivf::try_from(&pb::Ivf::decode(ivf_pb_bytes)?)?; + let ivf = IvfModel::try_from(pb::Ivf::decode(ivf_pb_bytes)?)?; + + let sub_index_metadata = index_reader + .schema() + .metadata + .get(S::metadata_key()) + .ok_or(Error::Index { + message: format!("{} not found", S::metadata_key()), + location: location!(), + })?; + let sub_index_metadata: Vec = serde_json::from_str(sub_index_metadata)?; let storage_reader = FileReader::try_open( scheduler @@ -135,81 +165,95 @@ impl IVFIndex { DecoderMiddlewareChain::default(), ) .await?; - let storage = IvfQuantizationStorage::open(storage_reader).await?; + let storage = IvfQuantizationStorage::try_new(storage_reader).await?; Ok(Self { uuid, ivf, reader: index_reader, storage, - sub_index_cache: Cache::new(DEFAULT_INDEX_CACHE_SIZE as u64), + partition_cache: Cache::new(DEFAULT_INDEX_CACHE_SIZE as u64), + sub_index_metadata, distance_type, session, + _marker: PhantomData, }) } #[instrument(level = "debug", skip(self))] - pub async fn load_partition(&self, partition_id: usize, write_cache: bool) -> Result> { + pub async fn load_partition( + &self, + partition_id: usize, + write_cache: bool, + ) -> Result>> { let cache_key = format!("{}-ivf-{}", self.uuid, partition_id); - let part_index = if let Some(part_idx) = self.sub_index_cache.get(&cache_key) { + let part_entry = if let Some(part_idx) = self.partition_cache.get(&cache_key) { part_idx } else { - if partition_id >= self.ivf.lengths.len() { + if partition_id >= self.ivf.num_partitions() { return Err(Error::Index { message: format!( "partition id {} is out of range of {} partitions", partition_id, - self.ivf.lengths.len() + self.ivf.num_partitions() ), location: location!(), }); } - let offset = self.ivf.offsets[partition_id]; - let length = self.ivf.lengths[partition_id] as usize; - let batches = self - .reader - .read_stream( - ReadBatchParams::Range(offset..offset + length), - 4096, - 16, - FilterExpression::no_filter(), - )? - .peekable() - .try_collect::>() - .await?; let schema = Arc::new(self.reader.schema().as_ref().into()); - let batch = concat_batches(&schema, batches.iter())?; - let idx = Arc::new(I::load(batch)?); + let batch = match self.reader.metadata().num_rows { + 0 => RecordBatch::new_empty(schema), + _ => { + let batches = self + .reader + .read_stream( + ReadBatchParams::Range(self.ivf.row_range(partition_id)), + u32::MAX, + 1, + FilterExpression::no_filter(), + )? + .try_collect::>() + .await?; + concat_batches(&schema, batches.iter())? 
     /// preprocess the query vector given the partition id.
     ///
     /// Internal API with no stability guarantees.
     pub fn preprocess_query(&self, partition_id: usize, query: &Query) -> Result<Query> {
-        if I::use_residual() {
-            let partition_centroids = self.ivf.centroids.value(partition_id);
+        if S::use_residual() {
+            let partition_centroids =
+                self.ivf
+                    .centroid(partition_id)
+                    .ok_or_else(|| Error::Index {
+                        message: format!("partition centroid {} does not exist", partition_id),
+                        location: location!(),
+                    })?;
             let residual_key = sub(&query.key, &partition_centroids)?;
             let mut part_query = query.clone();
             part_query.key = residual_key;
@@ -218,20 +262,10 @@
             Ok(query.clone())
         }
     }
-
-    pub fn find_partitions(&self, query: &Query) -> Result<UInt32Array> {
-        let dt = if self.distance_type == DistanceType::Cosine {
-            DistanceType::L2
-        } else {
-            self.distance_type
-        };
-
-        self.ivf.find_partitions(&query.key, query.nprobes, dt)
-    }
 }
 
 #[async_trait]
-impl<I: IvfSubIndex + 'static, Q: Quantization> Index for IVFIndex<I, Q> {
+impl<S: IvfSubIndex + 'static, Q: Quantization> Index for IVFIndex<S, Q> {
     fn as_any(&self) -> &dyn Any {
         self
     }
@@ -240,19 +274,22 @@ impl Index for IVFIndex {
 
+    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
+        Ok(self)
+    }
+
     fn index_type(&self) -> IndexType {
         IndexType::Vector
     }
 
     fn statistics(&self) -> Result<serde_json::Value> {
-        let partitions_statistics = self
-            .ivf
-            .lengths
-            .iter()
-            .map(|&len| IvfIndexPartitionStatistics { size: len })
+        let partitions_statistics = (0..self.ivf.num_partitions())
+            .map(|part_id| IvfIndexPartitionStatistics {
+                size: self.ivf.partition_size(part_id) as u32,
+            })
             .collect::<Vec<_>>();
 
-        let centroid_vecs = centroids_to_vectors(&self.ivf.centroids)?;
+        let centroid_vecs = centroids_to_vectors(self.ivf.centroids.as_ref().unwrap())?;
 
         Ok(serde_json::to_value(IvfIndexStatistics {
             index_type: "IVF".to_string(),
@@ -274,10 +311,11 @@
 }
 
 #[async_trait]
-impl<I: IvfSubIndex + 'static, Q: Quantization> VectorIndex
-    for IVFIndex<I, Q>
+impl<S: IvfSubIndex + 'static, Q: Quantization> VectorIndex
+    for IVFIndex<S, Q>
 {
     async fn search(&self, query: &Query, pre_filter: Arc<dyn PreFilter>) -> Result<RecordBatch> {
+        pre_filter.wait_for_ready().await?;
         let mut query = query.clone();
         if self.distance_type == DistanceType::Cosine {
             let key = normalize_arrow(&query.key)?;
@@ -312,6 +350,45 @@
         }
     }
 
+    fn find_partitions(&self, query: &Query) -> Result<UInt32Array> {
+        let dt = if self.distance_type == DistanceType::Cosine {
+            DistanceType::L2
+        } else {
+            self.distance_type
+        };
+
+        self.ivf.find_partitions(&query.key, query.nprobes, dt)
+    }
+
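Swapping Cosine for L2 during partition selection is sound because cosine queries are normalized in `search` above and the centroids are trained on normalized vectors: on unit vectors the two orderings agree, since ||x - y||^2 = 2 * (1 - <x, y>). A quick standalone check of that identity (illustrative only):

    // For unit vectors x and y: l2_sq(x, y) == 2.0 * (1.0 - dot(x, y)),
    // so ranking centroids by L2 gives the same order as cosine distance.
    fn l2_sq(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(a, b)| (a - b) * (a - b)).sum()
    }

    fn dot(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(a, b)| a * b).sum()
    }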
+    // async fn append(&self, batches: Vec<RecordBatch>) -> Result<()> {
+    //     IvfIndexBuilder::new(
+    //         dataset,
+    //         column,
+    //         index_dir,
+    //         distance_type,
+    //         shuffler,
+    //         ivf_params,
+    //         sub_index_params,
+    //         quantizer_params,
+    //     )
+    // }
+
+    async fn search_in_partition(
+        &self,
+        partition_id: usize,
+        query: &Query,
+        pre_filter: Arc<dyn PreFilter>,
+    ) -> Result<RecordBatch> {
+        let part_entry = self.load_partition(partition_id, true).await?;
+
+        let query = self.preprocess_query(partition_id, query)?;
+        let param = (&query).into();
+        // pre_filter.wait_for_ready().await?;
+        part_entry
+            .index
+            .search(query.key, query.k, param, &part_entry.storage, pre_filter)
+    }
+
     fn is_loadable(&self) -> bool {
         false
     }
@@ -353,21 +430,44 @@
 
+    fn ivf_model(&self) -> IvfModel {
+        self.ivf.clone()
+    }
+
+    fn quantizer(&self) -> Quantizer {
+        self.storage.quantizer::<Q>().unwrap()
+    }
+
+    /// the index type of this vector index.
+    fn sub_index_type(&self) -> (SubIndexType, QuantizationType) {
+        (S::name().try_into().unwrap(), Q::quantization_type())
+    }
+
     fn metric_type(&self) -> DistanceType {
         self.distance_type
     }
 }
 
+pub type IvfFlatIndex = IVFIndex<FlatIndex, FlatQuantizer>;
+pub type IvfHnswSqIndex = IVFIndex<HNSW, ScalarQuantizer>;
+
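With these aliases, a caller that wanted the two-stage path explicitly could drive it roughly as below; this is an illustrative sketch only, since the real fan-out lives in the `ANNIvfPartitionExec`/`ANNIvfSubIndexExec` plan nodes changed later in this patch:

    // Hypothetical driver over the trait methods added in this patch: pick the
    // candidate partitions, then search each one with the shared prefilter.
    async fn two_stage_search(
        index: &IvfHnswSqIndex,
        query: &Query,
        pre_filter: Arc<dyn PreFilter>,
    ) -> Result<Vec<RecordBatch>> {
        let partitions = index.find_partitions(query)?; // ids ordered by centroid distance
        let mut batches = Vec::with_capacity(partitions.len());
        for part_id in partitions.values() {
            batches.push(
                index
                    .search_in_partition(*part_id as usize, query, pre_filter.clone())
                    .await?,
            );
        }
        Ok(batches)
    }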
 #[cfg(test)]
 mod tests {
+    use std::collections::HashSet;
     use std::{collections::HashMap, ops::Range, sync::Arc};
 
+    use arrow::datatypes::UInt64Type;
     use arrow::{array::AsArray, datatypes::Float32Type};
     use arrow_array::{Array, FixedSizeListArray, RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
     use lance_arrow::FixedSizeListArrayExt;
-    use lance_index::DatasetIndexExt;
+    use lance_core::ROW_ID;
+    use lance_index::vector::hnsw::builder::HnswBuildParams;
+    use lance_index::vector::ivf::IvfBuildParams;
+    use lance_index::vector::sq::builder::SQBuildParams;
+    use lance_index::vector::DIST_COL;
+    use lance_index::{DatasetIndexExt, IndexType};
     use lance_linalg::distance::DistanceType;
     use lance_testing::datagen::generate_random_array_with_range;
     use tempfile::tempdir;
@@ -395,7 +495,9 @@ mod tests {
         )])
         .with_metadata(metadata)
         .into();
-        let array = Arc::new(FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap());
+        let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
+        let fsl = lance_linalg::kernels::normalize_fsl(&fsl).unwrap();
+        let array = Arc::new(fsl);
         let batch = RecordBatch::try_new(schema.clone(), vec![array.clone()]).unwrap();
 
         let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone());
@@ -427,7 +529,7 @@ mod tests {
     async fn test_build_ivf_flat() {
         let test_dir = tempdir().unwrap();
         let test_uri = test_dir.path().to_str().unwrap();
-        let (mut dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await;
+        let (mut dataset, vectors) = generate_test_dataset(test_uri, 0.0..1.0).await;
 
         let nlist = 16;
         let params = VectorIndexParams::ivf_flat(nlist, DistanceType::L2);
@@ -442,52 +544,123 @@ mod tests {
             .await
             .unwrap();
 
-        // TODO: test query after we replace the IVFIndex with the new one
-        // let query = vectors.value(0);
-        // let k = 100;
-        // let result = dataset
-        //     .scan()
-        //     .nearest("vector", query.as_primitive::<Float32Type>(), k)
-        //     .unwrap()
-        //     .nprobs(nlist)
-        //     .with_row_id()
-        //     .try_into_batch()
-        //     .await
-        //     .unwrap();
-
-        // let row_ids = result
-        //     .column_by_name(ROW_ID)
-        //     .unwrap()
-        //     .as_primitive::<UInt64Type>()
-        //     .values()
-        //     .to_vec();
-        // let dists = result
-        //     .column_by_name("_distance")
-        //     .unwrap()
-        //     .as_primitive::<Float32Type>()
-        //     .values()
-        //     .to_vec();
-        // let results = dists
-        //     .into_iter()
-        //     .zip(row_ids.into_iter())
-        //     .collect::<Vec<_>>();
-        // let row_ids = results.iter().map(|(_, id)| *id).collect::<HashSet<_>>();
-
-        // let gt = ground_truth(
-        //     &vectors,
-        //     query.as_primitive::<Float32Type>().values(),
-        //     k,
-        //     DistanceType::L2,
-        // );
-        // let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();
-
-        // let recall = row_ids.intersection(&gt_set).count() as f32 / k as f32;
-        // assert!(
-        //     recall >= 1.0,
-        //     "recall: {}\n results: {:?}\n\ngt: {:?}",
-        //     recall,
-        //     results,
-        //     gt,
-        // );
+        let query = vectors.value(0);
+        let k = 100;
+        let result = dataset
+            .scan()
+            .nearest("vector", query.as_primitive::<Float32Type>(), k)
+            .unwrap()
+            .nprobs(nlist)
+            .with_row_id()
+            .try_into_batch()
+            .await
+            .unwrap();
+
+        let row_ids = result[ROW_ID]
+            .as_primitive::<UInt64Type>()
+            .values()
+            .to_vec();
+        let dists = result[DIST_COL]
+            .as_primitive::<Float32Type>()
+            .values()
+            .to_vec();
+        let results = dists
+            .into_iter()
+            .zip(row_ids.into_iter())
+            .collect::<Vec<_>>();
+        let row_ids = results.iter().map(|(_, id)| *id).collect::<HashSet<_>>();
+
+        let gt = ground_truth(
+            &vectors,
+            query.as_primitive::<Float32Type>().values(),
+            k,
+            DistanceType::L2,
+        );
+        let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();
+
+        let recall = row_ids.intersection(&gt_set).count() as f32 / k as f32;
+        assert!(
+            recall >= 1.0,
+            "recall: {}\n results: {:?}\n\ngt: {:?}",
+            recall,
+            results,
+            gt,
+        );
+    }
+
+    async fn test_create_ivf_hnsw_sq(distance_type: DistanceType) {
+        let test_dir = tempdir().unwrap();
+        let test_uri = test_dir.path().to_str().unwrap();
+
+        let nlist = 4;
+        let (mut dataset, vectors) = generate_test_dataset(test_uri, 0.0..1.0).await;
+
+        let ivf_params = IvfBuildParams::new(nlist);
+        let sq_params = SQBuildParams::default();
+        let hnsw_params = HnswBuildParams::default();
+        let params = VectorIndexParams::with_ivf_hnsw_sq_params(
+            distance_type,
+            ivf_params,
+            hnsw_params,
+            sq_params,
+        );
+
+        dataset
+            .create_index(&["vector"], IndexType::Vector, None, &params, true)
+            .await
+            .unwrap();
+
+        let query = vectors.value(0);
+        let k = 100;
+        let result = dataset
+            .scan()
+            .nearest("vector", query.as_primitive::<Float32Type>(), k)
+            .unwrap()
+            .nprobs(nlist)
+            .with_row_id()
+            .try_into_batch()
+            .await
+            .unwrap();
+
+        let row_ids = result[ROW_ID]
+            .as_primitive::<UInt64Type>()
+            .values()
+            .to_vec();
+        let dists = result[DIST_COL]
+            .as_primitive::<Float32Type>()
+            .values()
+            .to_vec();
+        let results = dists
+            .into_iter()
+            .zip(row_ids.into_iter())
+            .collect::<Vec<_>>();
+        let row_ids = results.iter().map(|(_, id)| *id).collect::<HashSet<_>>();
+
+        let gt = ground_truth(
+            &vectors,
+            query.as_primitive::<Float32Type>().values(),
+            k,
+            distance_type,
+        );
+        let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();
+
+        let recall = row_ids.intersection(&gt_set).count() as f32 / k as f32;
+        assert!(
+            recall >= 0.9,
+            "recall: {}\n results: {:?}\n\ngt: {:?}",
+            recall,
+            results,
+            gt,
+        );
+    }
+
+    #[tokio::test]
+    async fn test_create_ivf_hnsw_sq_cosine() {
+        test_create_ivf_hnsw_sq(DistanceType::Cosine).await
+    }
+
+    #[tokio::test]
+    async fn test_create_ivf_hnsw_sq_dot() {
+        test_create_ivf_hnsw_sq(DistanceType::Dot).await
+    }
 }
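These tests all score results the same way: recall@k = |returned row ids ∩ ground truth| / k, with the ground truth coming from an exhaustive scan. Factored out, the check is simply (hypothetical helper, not part of the patch):

    use std::collections::HashSet;

    // recall@k = |results ∩ ground_truth| / k
    fn recall_at_k(results: &HashSet<u64>, ground_truth: &HashSet<u64>, k: usize) -> f32 {
        results.intersection(ground_truth).count() as f32 / k as f32
    }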
diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs
index 2980a522be..81ece189f0 100644
--- a/rust/lance/src/index/vector/pq.rs
+++ b/rust/lance/src/index/vector/pq.rs
@@ -6,6 +6,7 @@ use std::{any::Any, collections::HashMap};
 
 use arrow::compute::concat;
 use arrow_array::types::{Float16Type, Float32Type, Float64Type};
+use arrow_array::UInt32Array;
 use arrow_array::{
     cast::{as_primitive_array, AsArray},
     Array, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array,
@@ -18,8 +19,10 @@ use deepsize::DeepSizeOf;
 use lance_core::utils::tokio::spawn_cpu;
 use lance_core::ROW_ID;
 use lance_core::{utils::address::RowAddress, ROW_ID_FIELD};
+use lance_index::vector::ivf::storage::IvfModel;
 use lance_index::vector::pq::storage::ProductQuantizationStorage;
-use lance_index::vector::quantizer::Quantization;
+use lance_index::vector::quantizer::{Quantization, QuantizationType, Quantizer};
+use lance_index::vector::v3::subindex::SubIndexType;
 use lance_index::{
     vector::{pq::ProductQuantizer, Query, DIST_COL},
     Index, IndexType,
@@ -36,7 +39,6 @@ use tracing::{instrument, span, Level};
 pub use lance_index::vector::pq::{PQBuildParams, ProductQuantizerImpl};
 use lance_linalg::kernels::normalize_fsl;
 
-use super::ivf::Ivf;
 use super::VectorIndex;
 use crate::index::prefilter::PreFilter;
 use crate::index::vector::utils::maybe_sample_training_data;
@@ -132,6 +134,10 @@ impl Index for PQIndex {
         self
     }
 
+    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
+        Ok(self)
+    }
+
     fn index_type(&self) -> IndexType {
         IndexType::Vector
     }
@@ -211,6 +217,19 @@ impl VectorIndex for PQIndex {
             .await
     }
 
+    fn find_partitions(&self, _: &Query) -> Result<UInt32Array> {
+        unimplemented!("only for IVF")
+    }
+
+    async fn search_in_partition(
+        &self,
+        _: usize,
+        _: &Query,
+        _: Arc<dyn PreFilter>,
+    ) -> Result<RecordBatch> {
+        unimplemented!("only for IVF")
+    }
+
     fn is_loadable(&self) -> bool {
         true
     }
@@ -289,6 +308,18 @@ impl VectorIndex for PQIndex {
         Ok(())
     }
 
+    fn ivf_model(&self) -> IvfModel {
+        unimplemented!("only for IVF")
+    }
+    fn quantizer(&self) -> Quantizer {
+        unimplemented!("only for IVF")
+    }
+
+    /// the index type of this vector index.
+    fn sub_index_type(&self) -> (SubIndexType, QuantizationType) {
+        (SubIndexType::Flat, QuantizationType::Product)
+    }
+
     fn metric_type(&self) -> MetricType {
         self.metric_type
     }
@@ -309,7 +340,7 @@ pub(super) async fn build_pq_model(
     dim: usize,
     metric_type: MetricType,
     params: &PQBuildParams,
-    ivf: Option<&Ivf>,
+    ivf: Option<&IvfModel>,
 ) -> Result<Arc<dyn ProductQuantizer>> {
     if let Some(codebook) = &params.codebook {
         let mt = if metric_type == MetricType::Cosine {
@@ -381,7 +412,11 @@ pub(super) async fn build_pq_model(
         // Compute residual for PQ training.
         //
         // TODO: consolidate IVF models to `lance_index`.
-        let ivf2 = lance_index::vector::ivf::new_ivf(ivf.centroids.clone(), MetricType::L2, vec![]);
+        let ivf2 = lance_index::vector::ivf::new_ivf_transformer(
+            ivf.centroids.clone().unwrap(),
+            MetricType::L2,
+            vec![],
+        );
         span!(Level::INFO, "compute residual for PQ training")
             .in_scope(|| ivf2.compute_residual(&training_data))?
     } else {
@@ -470,7 +505,7 @@ mod tests {
         let centroids = generate_random_array_with_range(4 * DIM, -1.0..1.0);
         let fsl = FixedSizeListArray::try_new_from_values(centroids, DIM as i32).unwrap();
-        let ivf = Ivf::new(fsl);
+        let ivf = IvfModel::new(fsl);
         let params = PQBuildParams::new(16, 8);
         let pq = build_pq_model(&dataset, "vector", DIM, MetricType::L2, &params, Some(&ivf))
             .await
@@ -533,7 +568,11 @@ mod tests {
         let vectors = normalize_fsl(&vectors).unwrap();
 
         let row = vectors.slice(0, 1);
-        let ivf2 = lance_index::vector::ivf::new_ivf(ivf.centroids.clone(), MetricType::L2, vec![]);
+        let ivf2 = lance_index::vector::ivf::new_ivf_transformer(
+            ivf.centroids.clone().unwrap(),
+            MetricType::L2,
+            vec![],
+        );
         let residual_query = ivf2.compute_residual(&row).unwrap();
         let pq_code = pq.transform(&residual_query).unwrap();
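Both hunks above route residual computation through the renamed `new_ivf_transformer`. The residual itself is the element-wise difference between a vector and its assigned centroid, and PQ is trained on those residuals rather than the raw vectors. Conceptually (hypothetical scalar sketch, not lance's vectorized path):

    // residual(v) = v - centroids[assign(v)], computed element-wise.
    fn residual(vector: &[f32], centroid: &[f32]) -> Vec<f32> {
        vector.iter().zip(centroid).map(|(v, c)| v - c).collect()
    }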
diff --git a/rust/lance/src/index/vector/sq.rs b/rust/lance/src/index/vector/sq.rs
deleted file mode 100644
index 0542255f20..0000000000
--- a/rust/lance/src/index/vector/sq.rs
+++ /dev/null
@@ -1,75 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright The Lance Authors
-
-use std::sync::Arc;
-
-use arrow::datatypes::Float32Type;
-use arrow_array::{Array, RecordBatch};
-use lance_core::{Result, ROW_ID};
-use lance_index::vector::{
-    quantizer::Quantization,
-    sq::{builder::SQBuildParams, storage::ScalarQuantizationStorage, ScalarQuantizer},
-};
-use lance_linalg::{distance::MetricType, kernels::normalize_fsl};
-
-use crate::{index::vector::utils::maybe_sample_training_data, Dataset};
-
-pub(super) async fn build_sq_model(
-    dataset: &Dataset,
-    column: &str,
-    metric_type: MetricType,
-    params: &SQBuildParams,
-) -> Result<ScalarQuantizer> {
-    log::info!("Start to train SQ code: SQ{}", params.num_bits);
-    let expected_sample_size = 2usize.pow(params.num_bits as u32) * params.sample_rate;
-    log::info!(
-        "Loading training data for SQ. Sample size: {}",
-        expected_sample_size
-    );
-    let start = std::time::Instant::now();
-    let mut training_data =
-        maybe_sample_training_data(dataset, column, expected_sample_size).await?;
-    log::info!(
-        "Finished loading training data in {:02} seconds",
-        start.elapsed().as_secs_f32()
-    );
-
-    log::info!(
-        "starting to compute partitions for SQ training, sample size: {}",
-        training_data.value_length()
-    );
-
-    if metric_type == MetricType::Cosine {
-        log::info!("Normalize training data for SQ training: Cosine");
-        training_data = normalize_fsl(&training_data)?;
-    }
-
-    log::info!("Start train SQ: params={:#?}", params);
-    let sq = params.build(&training_data, MetricType::L2)?;
-    log::info!(
-        "Trained SQ{}[{:?}] in: {} seconds",
-        sq.num_bits(),
-        sq.bounds(),
-        start.elapsed().as_secs_f32()
-    );
-    Ok(sq)
-}
-
-pub fn build_sq_storage(
-    metric_type: MetricType,
-    row_ids: Arc<dyn Array>,
-    vectors: Arc<dyn Array>,
-    sq: ScalarQuantizer,
-) -> Result<ScalarQuantizationStorage> {
-    let code_column = sq.transform::<Float32Type>(vectors.as_ref())?;
-    std::mem::drop(vectors);
-
-    let batch = RecordBatch::try_from_iter_with_nullable(vec![
-        (ROW_ID, row_ids, true),
-        (sq.column(), code_column, false),
-    ])?;
-    let store =
-        ScalarQuantizationStorage::try_new(sq.num_bits(), metric_type, sq.bounds(), [batch])?;
-
-    Ok(store)
-}
diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs
index 76afb79bee..797f4b71c5 100644
--- a/rust/lance/src/io/exec/knn.rs
+++ b/rust/lance/src/io/exec/knn.rs
@@ -28,9 +28,7 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
 use itertools::Itertools;
 use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
 use lance_core::{ROW_ID, ROW_ID_FIELD};
-use lance_index::vector::{
-    flat::flat_search, Query, VectorIndex, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN,
-};
+use lance_index::vector::{flat::flat_search, Query, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN};
 use lance_io::stream::RecordBatchStream;
 use lance_linalg::distance::DistanceType;
 use lance_linalg::kernels::normalize_arrow;
@@ -44,7 +42,6 @@ use tracing::{instrument, Instrument};
 use crate::dataset::scanner::DatasetRecordBatchStream;
 use crate::dataset::Dataset;
 use crate::index::prefilter::{DatasetPreFilter, FilterLoader};
-use crate::index::vector::ivf::IVFIndex;
 use crate::index::DatasetIndexInternalExt;
 use crate::{Error, Result};
 
@@ -487,12 +484,7 @@ impl ExecutionPlan for ANNIvfPartitionExec {
                 let ds = ds.clone();
 
                 async move {
-                    let raw_index = ds.open_vector_index(&query.column, &uuid).await?;
-                    let index = raw_index.as_any().downcast_ref::<IVFIndex>().ok_or(
-                        DataFusionError::Execution(
-                            "ANNIVFPartitionExec: index is not a IVF type".to_string(),
-                        ),
-                    )?;
+                    let index = ds.open_vector_index(&query.column, &uuid).await?;
 
                     let mut query = query.clone();
                     if index.metric_type() == DistanceType::Cosine {
@@ -734,13 +726,7 @@ impl ExecutionPlan for ANNIvfSubIndexExec {
             .map(move |result| {
                 let query = query.clone();
                 async move {
-                    let (part_id, (raw_index, pre_filter)) = result?;
-
-                    let index = raw_index.as_any().downcast_ref::<IVFIndex>().ok_or(
-                        DataFusionError::Execution(
-                            "ANNSubIndexExec: sub-index is not a IVF type".to_string(),
-                        ),
-                    )?;
+                    let (part_id, (index, pre_filter)) = result?;
 
                     let mut query = query.clone();
                     if index.metric_type() == DistanceType::Cosine {
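Both deleted downcasts relied on knowing the concrete index struct; with `as_vector_index` on the `Index` trait (added earlier in this patch), the plan nodes can stay at the trait-object level. The pattern, reduced to its core (illustrative sketch only):

    // Hypothetical illustration: upcast a generic index handle to the vector
    // trait object instead of downcasting to a concrete IVFIndex.
    fn open_as_vector(index: Arc<dyn Index>) -> Result<Arc<dyn VectorIndex>> {
        index.as_vector_index()
    }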
diff --git a/rust/lance/src/session/index_extension.rs b/rust/lance/src/session/index_extension.rs
index 20d26cca2e..4148d399ba 100644
--- a/rust/lance/src/session/index_extension.rs
+++ b/rust/lance/src/session/index_extension.rs
@@ -65,10 +65,13 @@ mod test {
         sync::{atomic::AtomicBool, Arc},
     };
 
-    use arrow_array::RecordBatch;
+    use arrow_array::{RecordBatch, UInt32Array};
     use arrow_schema::Schema;
     use deepsize::DeepSizeOf;
     use lance_file::writer::{FileWriter, FileWriterOptions};
+    use lance_index::vector::ivf::storage::IvfModel;
+    use lance_index::vector::quantizer::{QuantizationType, Quantizer};
+    use lance_index::vector::v3::subindex::SubIndexType;
     use lance_index::{
         vector::{hnsw::VECTOR_ID_FIELD, Query},
         DatasetIndexExt, Index, IndexMetadata, IndexType, INDEX_FILE_NAME,
@@ -100,6 +103,10 @@ mod test {
             self
         }
 
+        fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
+            Ok(self)
+        }
+
         fn statistics(&self) -> Result<serde_json::Value> {
             Ok(json!(()))
         }
@@ -116,7 +123,20 @@ mod test {
     #[async_trait::async_trait]
     impl VectorIndex for MockIndex {
         async fn search(&self, _: &Query, _: Arc<dyn PreFilter>) -> Result<RecordBatch> {
-            todo!("panic")
+            unimplemented!()
+        }
+
+        fn find_partitions(&self, _: &Query) -> Result<UInt32Array> {
+            unimplemented!()
+        }
+
+        async fn search_in_partition(
+            &self,
+            _: usize,
+            _: &Query,
+            _: Arc<dyn PreFilter>,
+        ) -> Result<RecordBatch> {
+            unimplemented!()
         }
 
         fn is_loadable(&self) -> bool {
@@ -137,17 +157,29 @@ mod test {
             _: usize,
             _: usize,
         ) -> Result<Box<dyn VectorIndex>> {
-            todo!("panic")
+            unimplemented!()
         }
 
         fn row_ids(&self) -> Box<dyn Iterator<Item = &u64>> {
-            todo!("panic")
+            unimplemented!()
         }
 
         fn remap(&mut self, _: &HashMap<u64, Option<u64>>) -> Result<()> {
             Ok(())
         }
 
+        fn ivf_model(&self) -> IvfModel {
+            unimplemented!()
+        }
+        fn quantizer(&self) -> Quantizer {
+            unimplemented!()
+        }
+
+        /// the index type of this vector index.
+        fn sub_index_type(&self) -> (SubIndexType, QuantizationType) {
+            unimplemented!()
+        }
+
         fn metric_type(&self) -> MetricType {
             MetricType::L2
         }