Commit 3e5797f
…arch into jparismorgan/ivf-pq-consolidation
jparismorgan committed Jul 16, 2024
2 parents 7fd5590 + b1986c7 commit 3e5797f
Showing 8 changed files with 273 additions and 226 deletions.
1 change: 1 addition & 0 deletions apis/python/src/tiledb/vector_search/index.py
@@ -626,6 +626,7 @@ def clear_history(
         vspy.IndexIVFPQ.clear_history(ctx, uri, timestamp)
     else:
         raise ValueError(f"Unsupported index_type: {index_type}")
+    group.close()
     return

     ingestion_timestamps = [
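Note: the one-line change above closes the TileDB group handle on the early-return path, which previously returned without releasing it. A minimal sketch of the pattern, with an assumed URI and metadata key (not code from the repo):

import tiledb

# A tiledb.Group opened by handle stays open until close() is called,
# so every return path out of the function must close it.
group = tiledb.Group("file:///tmp/example_index", "r")  # assumed URI
try:
    index_type = group.meta["index_type"]  # assumed metadata key
    # ... dispatch to the matching clear_history implementation ...
finally:
    group.close()  # a `with tiledb.Group(...) as group:` block would close it automatically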
150 changes: 53 additions & 97 deletions apis/python/src/tiledb/vector_search/ingestion.py
@@ -1821,37 +1821,25 @@ def ingest_vectors_udf(
                 )
                 if source_type == "TILEDB_ARRAY":
                     logger.debug("Start indexing")
-                    if index_timestamp is None:
-                        ivf_index_tdb(
-                            dtype=vector_type,
-                            db_uri=source_uri,
-                            external_ids_uri=external_ids_uri,
-                            deleted_ids=StdVector_u64(updated_ids),
-                            centroids_uri=centroids_uri,
-                            parts_uri=partial_write_array_parts_uri,
-                            index_array_uri=partial_write_array_index_uri,
-                            id_uri=partial_write_array_ids_uri,
-                            start=part,
-                            end=part_end,
-                            nthreads=threads,
-                            config=config,
-                        )
-                    else:
-                        ivf_index_tdb(
-                            dtype=vector_type,
-                            db_uri=source_uri,
-                            external_ids_uri=external_ids_uri,
-                            deleted_ids=StdVector_u64(updated_ids),
-                            centroids_uri=centroids_uri,
-                            parts_uri=partial_write_array_parts_uri,
-                            index_array_uri=partial_write_array_index_uri,
-                            id_uri=partial_write_array_ids_uri,
-                            start=part,
-                            end=part_end,
-                            nthreads=threads,
-                            timestamp=index_timestamp,
-                            config=config,
-                        )
+                    ivf_index_tdb(
+                        dtype=vector_type,
+                        db_uri=source_uri,
+                        external_ids_uri=external_ids_uri,
+                        deleted_ids=StdVector_u64(updated_ids),
+                        centroids_uri=centroids_uri,
+                        parts_uri=partial_write_array_parts_uri,
+                        index_array_uri=partial_write_array_index_uri,
+                        id_uri=partial_write_array_ids_uri,
+                        start=part,
+                        end=part_end,
+                        nthreads=threads,
+                        **(
+                            {"timestamp": index_timestamp}
+                            if index_timestamp is not None
+                            else {}
+                        ),
+                        config=config,
+                    )
                 else:
                     in_vectors = read_input_vectors(
                         source_uri=source_uri,
@@ -1874,41 +1862,25 @@ def ingest_vectors_udf(
                         trace_id=trace_id,
                     )
                     logger.debug("Start indexing")
-                    if index_timestamp is None:
-                        ivf_index(
-                            dtype=vector_type,
-                            db=array_to_matrix(
-                                np.transpose(in_vectors).astype(vector_type)
-                            ),
-                            external_ids=StdVector_u64(external_ids),
-                            deleted_ids=StdVector_u64(updated_ids),
-                            centroids_uri=centroids_uri,
-                            parts_uri=partial_write_array_parts_uri,
-                            index_array_uri=partial_write_array_index_uri,
-                            id_uri=partial_write_array_ids_uri,
-                            start=part,
-                            end=part_end,
-                            nthreads=threads,
-                            config=config,
-                        )
-                    else:
-                        ivf_index(
-                            dtype=vector_type,
-                            db=array_to_matrix(
-                                np.transpose(in_vectors).astype(vector_type)
-                            ),
-                            external_ids=StdVector_u64(external_ids),
-                            deleted_ids=StdVector_u64(updated_ids),
-                            centroids_uri=centroids_uri,
-                            parts_uri=partial_write_array_parts_uri,
-                            index_array_uri=partial_write_array_index_uri,
-                            id_uri=partial_write_array_ids_uri,
-                            start=part,
-                            end=part_end,
-                            nthreads=threads,
-                            timestamp=index_timestamp,
-                            config=config,
-                        )
+                    ivf_index(
+                        dtype=vector_type,
+                        db=array_to_matrix(np.transpose(in_vectors).astype(vector_type)),
+                        external_ids=StdVector_u64(external_ids),
+                        deleted_ids=StdVector_u64(updated_ids),
+                        centroids_uri=centroids_uri,
+                        parts_uri=partial_write_array_parts_uri,
+                        index_array_uri=partial_write_array_index_uri,
+                        id_uri=partial_write_array_ids_uri,
+                        start=part,
+                        end=part_end,
+                        nthreads=threads,
+                        **(
+                            {"timestamp": index_timestamp}
+                            if index_timestamp is not None
+                            else {}
+                        ),
+                        config=config,
+                    )

     def ingest_additions_udf(
         index_group_uri: str,
@@ -1949,37 +1921,21 @@ def ingest_additions_udf(
             return

         logger.debug(f"Ingesting additions {partial_write_array_index_uri}")
-        if index_timestamp is None:
-            ivf_index(
-                dtype=vector_type,
-                db=array_to_matrix(np.transpose(additions_vectors).astype(vector_type)),
-                external_ids=StdVector_u64(additions_external_ids),
-                deleted_ids=StdVector_u64(np.array([], np.uint64)),
-                centroids_uri=centroids_uri,
-                parts_uri=partial_write_array_parts_uri,
-                index_array_uri=partial_write_array_index_uri,
-                id_uri=partial_write_array_ids_uri,
-                start=write_offset,
-                end=0,
-                nthreads=threads,
-                config=config,
-            )
-        else:
-            ivf_index(
-                dtype=vector_type,
-                db=array_to_matrix(np.transpose(additions_vectors).astype(vector_type)),
-                external_ids=StdVector_u64(additions_external_ids),
-                deleted_ids=StdVector_u64(np.array([], np.uint64)),
-                centroids_uri=centroids_uri,
-                parts_uri=partial_write_array_parts_uri,
-                index_array_uri=partial_write_array_index_uri,
-                id_uri=partial_write_array_ids_uri,
-                start=write_offset,
-                end=0,
-                nthreads=threads,
-                timestamp=index_timestamp,
-                config=config,
-            )
+        ivf_index(
+            dtype=vector_type,
+            db=array_to_matrix(np.transpose(additions_vectors).astype(vector_type)),
+            external_ids=StdVector_u64(additions_external_ids),
+            deleted_ids=StdVector_u64(np.array([], np.uint64)),
+            centroids_uri=centroids_uri,
+            parts_uri=partial_write_array_parts_uri,
+            index_array_uri=partial_write_array_index_uri,
+            id_uri=partial_write_array_ids_uri,
+            start=write_offset,
+            end=0,
+            nthreads=threads,
+            **({"timestamp": index_timestamp} if index_timestamp is not None else {}),
+            config=config,
+        )

     def compute_partition_indexes_udf(
         index_group_uri: str,
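All three call sites in this file now use the same idiom: a dict spliced into the keyword arguments only when a timestamp was provided, so the callee's default applies otherwise. A self-contained toy sketch of the idiom (function and names invented for illustration):

# Toy demonstration of the conditional-kwargs splice used in the diff above.
def write_partition(data, timestamp=None, config=None):
    return {"data": data, "timestamp": timestamp, "config": config}

index_timestamp = None  # pretend no explicit timestamp was requested
result = write_partition(
    "vectors",
    # Empty dict when index_timestamp is None: no `timestamp` kwarg is
    # passed at all, and the callee keeps its default.
    **({"timestamp": index_timestamp} if index_timestamp is not None else {}),
    config={},
)
assert result["timestamp"] is None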
8 changes: 4 additions & 4 deletions apis/python/src/tiledb/vector_search/module.cc
@@ -303,7 +303,7 @@ static void declare_ivf_index(py::module& m, const std::string& suffix) {
   m.def(
       ("ivf_index_" + suffix).c_str(),
       [](tiledb::Context& ctx,
-         const ColMajorMatrix<T>& db,
+         const ColMajorMatrix<T>& input_vectors,
          const std::vector<uint64_t>& external_ids,
          const std::vector<uint64_t>& deleted_ids,
          const std::string& centroids_uri,
@@ -316,7 +316,7 @@ static void declare_ivf_index(py::module& m, const std::string& suffix) {
          uint64_t timestamp) -> int {
         return detail::ivf::ivf_index<T, uint64_t, float>(
             ctx,
-            db,
+            input_vectors,
             external_ids,
             deleted_ids,
             centroids_uri,
@@ -338,7 +338,7 @@ static void declare_ivf_index_tdb(py::module& m, const std::string& suffix) {
   m.def(
       ("ivf_index_tdb_" + suffix).c_str(),
       [](tiledb::Context& ctx,
-         const std::string& db_uri,
+         const std::string& input_vectors_uri,
          const std::string& external_ids_uri,
          const std::vector<uint64_t>& deleted_ids,
          const std::string& centroids_uri,
@@ -351,7 +351,7 @@ static void declare_ivf_index_tdb(py::module& m, const std::string& suffix) {
          uint64_t timestamp) -> int {
         return detail::ivf::ivf_index<T, uint64_t, float>(
             ctx,
-            db_uri,
+            input_vectors_uri,
             external_ids_uri,
             deleted_ids,
             centroids_uri,
38 changes: 20 additions & 18 deletions apis/python/test/test_ingestion.py
@@ -681,9 +681,6 @@ def test_ingestion_timetravel(tmp_path):
             timestamp=20,
         )

-        # if index_type == "IVF_PQ":
-        #     # TODO(SC-48888): Fix consolidation for IVF_PQ.
-        #     continue
         index = index.consolidate_updates()

         # We still have no results before timestamp 10.
@@ -784,6 +781,9 @@ def test_ingestion_timetravel(tmp_path):
         second_num_edges = num_edges_history[1]

         # Clear all history at timestamp 19.
+        # With type-erased indexes, we cannot call clear_history() while the index is open because they
+        # open up a TileDB Array during query(). Deleting fragments while the array is open is not allowed.
+        index = None
         Index.clear_history(uri=index_uri, timestamp=19)

         with tiledb.Group(index_uri, "r") as group:
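The change above introduces the pattern that recurs through the rest of this file: drop the Python reference to the index before calling Index.clear_history(), so the TileDB arrays held open by the type-erased index are closed before fragments are deleted. A hedged sketch (the index class and URI here are assumptions for illustration):

from tiledb.vector_search.flat_index import FlatIndex  # assumed index class
from tiledb.vector_search.index import Index

index_uri = "/tmp/example_index"   # assumed local path
index = FlatIndex(uri=index_uri)   # querying keeps TileDB arrays open
index = None                       # release the handle so the arrays close...
Index.clear_history(uri=index_uri, timestamp=19)  # ...then fragment deletion is safe
index = FlatIndex(uri=index_uri)   # reopen to continue querying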
@@ -954,7 +954,7 @@ def test_ingestion_with_batch_updates(tmp_path):
     gt_i, gt_d = get_groundtruth(dataset_dir, k)

     for index_type, index_class in zip(INDEXES, INDEX_CLASSES):
-        minimum_accuracy = 0.85 if index_type == "IVF_PQ" else 0.99
+        minimum_accuracy = 0.84 if index_type == "IVF_PQ" else 0.99

         index_uri = os.path.join(tmp_path, f"array_{index_type}")
         index = ingest(
@@ -992,9 +992,6 @@ def test_ingestion_with_batch_updates(tmp_path):
         index_uri = move_local_index_to_new_location(index_uri)
         index = index_class(uri=index_uri)

-        # if index_type == "IVF_PQ":
-        #     # TODO(SC-48888): Fix consolidation for IVF_PQ.
-        #     continue
         index = index.consolidate_updates()
         _, result = index.query(queries, k=k, nprobe=nprobe)
         assert accuracy(result, gt_i, updated_ids=updated_ids) > minimum_accuracy
@@ -1115,9 +1112,6 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
         assert accuracy(result, gt_i) == 1.0

         # Consolidate updates
-        # if index_type == "IVF_PQ":
-        #     # TODO(SC-48888): Fix consolidation for IVF_PQ.
-        #     continue
         index = index.consolidate_updates()

         ingestion_timestamps, base_sizes = load_metadata(index_uri)
@@ -1188,10 +1182,12 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
         second_num_edges = num_edges_history[1]

         # Clear history before the latest ingestion
+        latest_ingestion_timestamp = index.latest_ingestion_timestamp
+        assert index.latest_ingestion_timestamp == 102
-        Index.clear_history(
-            uri=index_uri, timestamp=index.latest_ingestion_timestamp - 1
-        )
+        # With type-erased indexes, we cannot call clear_history() while the index is open because they
+        # open up a TileDB Array during query(). Deleting fragments while the array is open is not allowed.
+        index = None
+        Index.clear_history(uri=index_uri, timestamp=latest_ingestion_timestamp - 1)

         with tiledb.Group(index_uri, "r") as group:
             assert metadata_to_list(group, "ingestion_timestamps") == [102]
@@ -1234,7 +1230,12 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
             assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0

         # Clear all history
-        Index.clear_history(uri=index_uri, timestamp=index.latest_ingestion_timestamp)
+        latest_ingestion_timestamp = index.latest_ingestion_timestamp
+        assert index.latest_ingestion_timestamp == 102
+        # With type-erased indexes, we cannot call clear_history() while the index is open because they
+        # open up a TileDB Array during query(). Deleting fragments while the array is open is not allowed.
+        index = None
+        Index.clear_history(uri=index_uri, timestamp=latest_ingestion_timestamp)
         index = index_class(uri=index_uri, timestamp=1)
         _, result = index.query(queries, k=k, nprobe=partitions)
         assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
@@ -1320,9 +1321,6 @@ def test_ingestion_with_additions_and_timetravel(tmp_path):
         _, result = index.query(queries, k=k, nprobe=partitions, l_search=k * 2)
         assert 0.45 < accuracy(result, gt_i)

-        # if index_type == "IVF_PQ":
-        #     # TODO(SC-48888): Fix consolidation for IVF_PQ.
-        #     continue
         index = index.consolidate_updates()
         _, result = index.query(queries, k=k, nprobe=partitions, l_search=k * 2)
         assert 0.45 < accuracy(result, gt_i)
@@ -1834,7 +1832,11 @@ def test_ivf_flat_ingestion_with_training_source_uri_tdb(tmp_path):
     )

     # Clear the index history, load, update, and query.
-    Index.clear_history(uri=index_uri, timestamp=index.latest_ingestion_timestamp - 1)
+    # With type-erased indexes, we cannot call clear_history() while the index is open because they
+    # open up a TileDB Array during query(). Deleting fragments while the array is open is not allowed.
+    latest_ingestion_timestamp = index.latest_ingestion_timestamp
+    index = None
+    Index.clear_history(uri=index_uri, timestamp=latest_ingestion_timestamp - 1)

     index = IVFFlatIndex(uri=index_uri)