use retrain index in consolidate_updates()
jparismorgan committed Jul 16, 2024
1 parent 3e5797f commit fecd658
Showing 3 changed files with 35 additions and 28 deletions.
24 changes: 12 additions & 12 deletions apis/python/src/tiledb/vector_search/index.py
@@ -479,7 +479,7 @@ def delete_batch(self, external_ids: np.array, timestamp: int = None):
 
     def consolidate_updates(self, retrain_index: bool = False, **kwargs):
         """
-        Consolidates updates by merging updates form the updates table into the base index.
+        Consolidates updates by merging updates from the updates table into the base index.
         The consolidation process is used to avoid query latency degradation as more updates
         are added to the index. It triggers a base index re-indexing, merging the non-consolidated
@@ -489,10 +489,10 @@ def consolidate_updates(self, retrain_index: bool = False, **kwargs):
         ----------
         retrain_index: bool
             If true, retrain the index. If false, reuse data from the previous index.
-            For IVF_FLAT retraining means we will recompute the centroids - when doing so you can
-            pass any ingest() arguments used to configure computing centroids and we will use them
-            when recomputing the centroids. Otherwise, if false, we will reuse the centroids from
-            the previous index.
+            For IVF_FLAT and IVF_PQ retraining means we will recompute the centroids - when doing
+            so you can pass any ingest() arguments used to configure computing centroids and we will
+            use them when recomputing the centroids. Otherwise, if false, we will reuse the centroids
+            from the previous index.
         **kwargs
             Extra kwargs passed here are passed to `ingest` function.
         """
@@ -516,18 +516,19 @@ def consolidate_updates(self, retrain_index: bool = False, **kwargs):
         tiledb.consolidate(self.updates_array_uri, config=conf)
         tiledb.vacuum(self.updates_array_uri, config=conf)
 
+        copy_centroids_uri = None
         # We don't copy the centroids if self.partitions=0 because this means our index was previously empty.
-        should_pass_copy_centroids_uri = (
-            self.index_type == "IVF_FLAT" and not retrain_index and self.partitions > 0
-        )
-        if should_pass_copy_centroids_uri:
+        if self.index_type == "IVF_FLAT" and not retrain_index and self.partitions > 0:
             # Make sure the user didn't pass an incorrect number of partitions.
             if "partitions" in kwargs and self.partitions != kwargs["partitions"]:
                 raise ValueError(
                     f"The passed partitions={kwargs['partitions']} is different than the number of partitions ({self.partitions}) from when the index was created - this is an issue because with retrain_index=True, the partitions from the previous index will be used; to fix, set retrain_index=False, don't pass partitions, or pass the correct number of partitions."
                 )
             # We pass partitions through kwargs so that we don't pass it twice.
             kwargs["partitions"] = self.partitions
+            copy_centroids_uri = self.centroids_uri
+        if self.index_type == "IVF_PQ" and not retrain_index:
+            copy_centroids_uri = True
 
         # print('[index@consolidate_updates] self.centroids_uri', self.centroids_uri)
         print("[index@consolidate_updates] self.uri", self.uri)
@@ -539,6 +540,7 @@ def consolidate_updates(self, retrain_index: bool = False, **kwargs):
         )
         print("[index@consolidate_updates] self.max_timestamp", max_timestamp)
         print("[index@consolidate_updates] self.storage_version", self.storage_version)
+        print("[index@consolidate_updates] copy_centroids_uri", copy_centroids_uri)
 
         new_index = ingest(
             index_type=self.index_type,
@@ -550,9 +552,7 @@ def consolidate_updates(self, retrain_index: bool = False, **kwargs):
             updates_uri=self.updates_array_uri,
             index_timestamp=max_timestamp,
             storage_version=self.storage_version,
-            copy_centroids_uri=self.centroids_uri
-            if should_pass_copy_centroids_uri
-            else None,
+            copy_centroids_uri=copy_centroids_uri,
             config=self.config,
             **kwargs,
         )
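Read together, the two hunks above reduce the new centroid-reuse decision to a small pure function. This is a restatement for clarity, not a helper that exists in the codebase; note that the IVF_PQ sentinel is the boolean True rather than a URI:

    def centroids_to_copy(index_type, retrain_index, partitions, centroids_uri):
        # Reuse stored IVF_FLAT centroids only when not retraining and the
        # previous index was non-empty (partitions > 0).
        if index_type == "IVF_FLAT" and not retrain_index and partitions > 0:
            return centroids_uri
        # IVF_PQ reuses its centroids in place; True acts as a flag, not a URI.
        if index_type == "IVF_PQ" and not retrain_index:
            return True
        # Otherwise retrain from scratch (or the index was previously empty).
        return None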
37 changes: 23 additions & 14 deletions apis/python/src/tiledb/vector_search/ingestion.py
@@ -224,6 +224,25 @@ def ingest(
         raise ValueError("source_uri should not be provided alongside input_vectors")
     if source_type and input_vectors:
         raise ValueError("source_type should not be provided alongside input_vectors")
+
+    for variable in [
+        "training_input_vectors",
+        "training_source_uri",
+        "training_source_type",
+    ]:
+        if index_type != "IVF_FLAT" and locals().get(variable) is not None:
+            raise ValueError(
+                f"{variable} should only be provided with index_type IVF_FLAT"
+            )
+
+    if (
+        index_type != "IVF_FLAT"
+        and index_type != "IVF_PQ"
+        and locals().get("copy_centroids_uri") is not None
+    ):
+        raise ValueError(
+            "copy_centroids_uri should only be provided with index_type IVF_FLAT"
+        )
 
     if training_source_uri and training_sample_size != -1:
         raise ValueError(
@@ -257,7 +276,7 @@ def ingest(
         raise ValueError(
             "training_sample_size should not be provided alongside copy_centroids_uri"
         )
-    if copy_centroids_uri is not None and partitions == -1:
+    if index_type == "IVF_FLAT" and copy_centroids_uri is not None and partitions == -1:
         raise ValueError(
             "partitions should be provided if copy_centroids_uri is provided (set partitions to the number of centroids in copy_centroids_uri)"
         )
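The guards added in this file look arguments up by name with locals().get(...) so that a single loop can police several parameters. A standalone sketch of the pattern, using a hypothetical function f that is not from the codebase:

    def f(a=None, b=None):
        # locals() maps parameter names to their current values, so a list of
        # names can drive a uniform "should not be provided" style check.
        for name in ["a", "b"]:
            if locals().get(name) is not None:
                print(f"{name} was provided")

    f(b=1)  # prints: b was provided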
@@ -266,16 +285,6 @@ def ingest(
         raise ValueError(
             "training_sample_size should only be provided with index_type IVF_FLAT"
         )
-    for variable in [
-        "copy_centroids_uri",
-        "training_input_vectors",
-        "training_source_uri",
-        "training_source_type",
-    ]:
-        if index_type != "IVF_FLAT" and locals().get(variable) is not None:
-            raise ValueError(
-                f"{variable} should only be provided with index_type IVF_FLAT"
-            )
 
     for variable in [
         "copy_centroids_uri",
@@ -1573,7 +1582,7 @@ def ingest_type_erased(
         dimensions: int,
         size: int,
         batch: int,
-        arrays_created: bool,
+        retrain_index: bool,
        config: Optional[Mapping[str, Any]] = None,
        verbose: bool = False,
        trace_id: Optional[str] = None,
@@ -1605,7 +1614,7 @@ def ingest_type_erased(
             trace_id=trace_id,
         )
 
-        if arrays_created and index_type == "IVF_PQ":
+        if retrain_index and index_type == "IVF_PQ":
             # For IVF_PQ, we cannot re-ingest the data, as we only store the PQ encoded
             # vectors. Instead leave the centroids and just update the stored vectors.
             print(
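The comment in this hunk states the key constraint: an IVF_PQ index stores only product-quantized codes, so the original vectors cannot be re-ingested. A toy NumPy sketch of why PQ encoding is lossy; this is illustrative only and not the library's implementation, which trains sub-codebooks with k-means rather than sampling:

    import numpy as np

    rng = np.random.default_rng(0)
    vectors = rng.random((1000, 8), dtype=np.float32)

    # Two 4-d subspaces with 16 "centroids" each (sampled here instead of k-means).
    codebooks = [vectors[rng.choice(1000, 16), i * 4:(i + 1) * 4] for i in range(2)]

    # Encode: each 8-float vector becomes two one-byte codes.
    codes = np.stack(
        [
            np.argmin(
                ((vectors[:, i * 4:(i + 1) * 4][:, None, :] - cb[None]) ** 2).sum(-1),
                axis=1,
            )
            for i, cb in enumerate(codebooks)
        ],
        axis=1,
    ).astype(np.uint8)

    # Decode: only centroid approximations come back, not the original data,
    # which is why ingestion must keep the centroids and update the stored codes.
    decoded = np.concatenate([cb[codes[:, i]] for i, cb in enumerate(codebooks)], axis=1)
    assert decoded.shape == vectors.shape and not np.allclose(decoded, vectors)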
@@ -2330,7 +2339,7 @@ def scale_resources(min_resource, max_resource, max_input_size, input_size):
             dimensions=dimensions,
             size=size,
             batch=input_vectors_batch_size,
-            arrays_created=arrays_created,
+            retrain_index=copy_centroids_uri is None,
             config=config,
             verbose=verbose,
             trace_id=trace_id,
2 changes: 0 additions & 2 deletions src/include/index/ivf_pq_index.h
@@ -761,8 +761,6 @@ class ivf_pq_index {
         training_set_ids.end(),
         feature_vectors_.ids());
 
-    auto num_unique_labels = ::num_vectors(flat_ivf_centroids_);
-
     train_pq(training_set);  // cluster_centroids_, distance_tables_
     train_ivf(training_set);  // flat_ivf_centroids_
     std::cout << "[ivf_pq_index@add] pq_ivf_centroids_ = "
