write feature_vector array when we update the index
jparismorgan committed Jul 16, 2024
1 parent 38b12c6 commit a0c0830
Showing 5 changed files with 293 additions and 58 deletions.
87 changes: 52 additions & 35 deletions apis/python/src/tiledb/vector_search/ingestion.py
@@ -1587,6 +1587,10 @@ def ingest_type_erased(
verbose: bool = False,
trace_id: Optional[str] = None,
):
print("[ingestion@ingest_type_erased] retrain_index", retrain_index)
print("[ingestion@ingest_type_erased] size", size)
print("[ingestion@ingest_type_erased] batch", batch)
print("[ingestion@ingest_type_erased] dimensions", dimensions)
import numpy as np

import tiledb.cloud
@@ -1613,41 +1617,14 @@ def ingest_type_erased(
verbose=verbose,
trace_id=trace_id,
)

if not retrain_index and index_type == "IVF_PQ":
print(
"[ingestion@ingest_type_erased] additions_vectors:",
additions_vectors,
)
print(
"[ingestion@ingest_type_erased] additions_external_ids:",
additions_external_ids,
)
ctx = vspy.Ctx(config)
index = vspy.IndexIVFPQ(ctx, index_group_uri)
if (
additions_vectors is not None
or additions_external_ids is not None
or updated_ids is not None
):
vectors_to_add = vspy.FeatureVectorArray(
np.transpose(additions_vectors)
if additions_vectors is not None
else np.array([[]], dtype=vector_type),
np.transpose(additions_external_ids)
if additions_external_ids is not None
else np.array([], dtype=np.uint64),
)
vector_ids_to_remove = vspy.FeatureVector(
updated_ids
if updated_ids is not None
else np.array([], np.uint64)
)
index.update(vectors_to_add, vector_ids_to_remove)
index.write_index(
ctx, index_group_uri, to_temporal_policy(index_timestamp)
)
return
print(
"[ingestion@ingest_type_erased] additions_vectors:",
additions_vectors,
)
print(
"[ingestion@ingest_type_erased] additions_external_ids:",
additions_external_ids,
)

temp_data_group_uri = f"{index_group_uri}/{PARTIAL_WRITE_ARRAY_DIR}"
temp_data_group = tiledb.Group(temp_data_group_uri, "w")
@@ -1674,7 +1651,14 @@ def ingest_type_erased(
part_end = part + batch
if part_end > size:
part_end = size

# First we get each vector and its external ID from the input data.
print("[ingestion@ingest_type_erased] source_uri:", source_uri)
print("[ingestion@ingest_type_erased] source_type:", source_type)
print("[ingestion@ingest_type_erased] vector_type:", vector_type)
print("[ingestion@ingest_type_erased] dimensions:", dimensions)
print("[ingestion@ingest_type_erased] part:", part)
print("[ingestion@ingest_type_erased] part_end:", part_end)
in_vectors = read_input_vectors(
source_uri=source_uri,
source_type=source_type,
@@ -1686,6 +1670,7 @@ def ingest_type_erased(
verbose=verbose,
trace_id=trace_id,
)
print("[ingestion@ingest_type_erased] in_vectors:", in_vectors)
external_ids = read_external_ids(
external_ids_uri=external_ids_uri,
external_ids_type=external_ids_type,
@@ -1695,6 +1680,7 @@
verbose=verbose,
trace_id=trace_id,
)
print("[ingestion@ingest_type_erased] external_ids:", external_ids)

# Then check if the external id is in the updated ids.
updates_filter = np.in1d(
@@ -1703,6 +1689,14 @@
# We only keep the vectors and external ids that are not in the updated ids.
in_vectors = in_vectors[updates_filter]
external_ids = external_ids[updates_filter]
print(
"[ingestion@ingest_type_erased] in_vectors after filter:",
in_vectors,
)
print(
"[ingestion@ingest_type_erased] external_ids after filter:",
external_ids,
)
vector_len = len(in_vectors)
if vector_len > 0:
end_offset = write_offset + vector_len
@@ -1736,6 +1730,29 @@ def ingest_type_erased(
parts_array.close()
ids_array.close()

if index_type == "IVF_PQ" and not retrain_index:
ctx = vspy.Ctx(config)
index = vspy.IndexIVFPQ(ctx, index_group_uri)
if (
additions_vectors is not None
or additions_external_ids is not None
or updated_ids is not None
):
vectors_to_add = vspy.FeatureVectorArray(
np.transpose(additions_vectors)
if additions_vectors is not None
else np.array([[]], dtype=vector_type),
np.transpose(additions_external_ids)
if additions_external_ids is not None
else np.array([], dtype=np.uint64),
)
vector_ids_to_remove = vspy.FeatureVector(
updated_ids if updated_ids is not None else np.array([], np.uint64)
)
index.update(vectors_to_add, vector_ids_to_remove)
index.write_index(ctx, index_group_uri, to_temporal_policy(index_timestamp))
return

# Now that we've ingested the vectors and their IDs, train the index with the data.
ctx = vspy.Ctx(config)
if index_type == "VAMANA":
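Taken together, these hunks relocate the IVF_PQ no-retrain branch: it used to return early, before the loop that populates the partial parts/ids arrays, and now it runs after that loop, so the feature_vector data is written first and only then is the index updated, which is what the commit title describes. A condensed sketch of the relocated branch follows, assuming the same vspy bindings the file already uses; the helper wrapper and its parameter list are illustrative, not part of the diff:

    import numpy as np

    # Hedged sketch (not the diff itself) of the relocated IVF_PQ update path.
    # `vspy` and `to_temporal_policy` are the bindings the module already imports;
    # everything else is passed in so the sketch is self-contained.
    def apply_ivf_pq_updates(vspy, to_temporal_policy, config, index_group_uri,
                             additions_vectors, additions_external_ids,
                             updated_ids, vector_type, index_timestamp):
        ctx = vspy.Ctx(config)
        index = vspy.IndexIVFPQ(ctx, index_group_uri)
        if (additions_vectors is None and additions_external_ids is None
                and updated_ids is None):
            return
        # Vectors are handed over transposed (one column per vector); empty
        # arrays stand in for "no additions" so the constructor is always valid.
        vectors_to_add = vspy.FeatureVectorArray(
            np.transpose(additions_vectors) if additions_vectors is not None
            else np.array([[]], dtype=vector_type),
            np.transpose(additions_external_ids) if additions_external_ids is not None
            else np.array([], dtype=np.uint64),
        )
        # IDs listed in updated_ids are removed before the replacements are added.
        vector_ids_to_remove = vspy.FeatureVector(
            updated_ids if updated_ids is not None else np.array([], dtype=np.uint64)
        )
        index.update(vectors_to_add, vector_ids_to_remove)
        index.write_index(ctx, index_group_uri, to_temporal_policy(index_timestamp))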
72 changes: 71 additions & 1 deletion apis/python/test/test_index.py
@@ -273,6 +273,7 @@ def test_vamana_index(tmp_path):
# During the first ingestion we overwrite the metadata and end up with a single base size and ingestion timestamp.
ingestion_timestamps, base_sizes = load_metadata(uri)
assert base_sizes == [5]
assert len(ingestion_timestamps) == 1
timestamp_5_minutes_from_now = int((time.time() + 5 * 60) * 1000)
timestamp_5_minutes_ago = int((time.time() - 5 * 60) * 1000)
assert (
@@ -316,6 +317,9 @@ def test_ivf_pq_index(tmp_path):
os.rmdir(uri)
vector_type = np.float32

print(
"[test_index] ivf_pq_index.create() --------------------------------------------------------"
)
index = ivf_pq_index.create(
uri=uri,
dimensions=3,
@@ -342,6 +346,9 @@
update_vectors[2] = np.array([2, 2, 2], dtype=np.dtype(np.float32))
update_vectors[3] = np.array([3, 3, 3], dtype=np.dtype(np.float32))
update_vectors[4] = np.array([4, 4, 4], dtype=np.dtype(np.float32))
print(
"[test_index] index.update_batch() --------------------------------------------------------"
)
index.update_batch(
vectors=update_vectors,
external_ids=np.array([0, 1, 2, 3, 4], dtype=np.dtype(np.uint32)),
@@ -350,7 +357,70 @@
index, np.array([[2, 2, 2]], dtype=np.float32), 2, [[0, 3]], [[2, 1]]
)

# TODO(paris): Add tests for consolidation once we enable it.
# By default we do not re-train the index. This means we won't be able to find any results.
print(
"[test_index] index.consolidate_updates() --------------------------------------------------------"
)
index = index.consolidate_updates(retrain_index=False)
for i in range(5):
distances, ids = index.query(np.array([[i, i, i]], dtype=np.float32), k=1)
assert np.array_equal(ids, np.array([[MAX_UINT64]], dtype=np.uint64))
assert np.array_equal(distances, np.array([[MAX_FLOAT32]], dtype=np.float32))

# We can retrain the index and find the results. Update ID 4 to 44 while we do that.
print(
"[test_index] index.delete() --------------------------------------------------------"
)
index.delete(external_id=4)
print(
"[test_index] index.update() --------------------------------------------------------"
)
index.update(vector=np.array([4, 4, 4], dtype=np.dtype(np.float32)), external_id=44)
print(
"[test_index] index.consolidate_updates() --------------------------------------------------------"
)
index = index.consolidate_updates(retrain_index=True)
return
# During the first ingestion we overwrite the metadata and end up with a single base size and ingestion timestamp.
ingestion_timestamps, base_sizes = load_metadata(uri)
assert base_sizes == [5]
assert len(ingestion_timestamps) == 1
timestamp_5_minutes_from_now = int((time.time() + 5 * 60) * 1000)
timestamp_5_minutes_ago = int((time.time() - 5 * 60) * 1000)
assert (
ingestion_timestamps[0] > timestamp_5_minutes_ago
and ingestion_timestamps[0] < timestamp_5_minutes_from_now
)

# Test that we can query with multiple query vectors.
for i in range(5):
query_and_check_distances(
index,
np.array([[i, i, i], [i, i, i]], dtype=np.float32),
1,
[[0], [0]],
[[i], [i]],
)

# Test that we can query with k > 1.
query_and_check_distances(
index, np.array([[0, 0, 0]], dtype=np.float32), 2, [[0, 3]], [[0, 1]]
)

# Test that we can query with multiple query vectors and k > 1.
query_and_check_distances(
index,
np.array([[0, 0, 0], [4, 4, 4]], dtype=np.float32),
2,
[[0, 3], [0, 3]],
[[0, 1], [4, 3]],
)

vfs = tiledb.VFS()

assert vfs.dir_size(uri) > 0
Index.delete_index(uri=uri, config={})
assert vfs.dir_size(uri) == 0


def test_delete_invalid_index(tmp_path):
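The new assertions above exercise two behaviors: with retrain_index=False, an IVF_PQ query returns the "not found" sentinels (MAX_UINT64 for ids, MAX_FLOAT32 for distances), and only after retraining do the updated vectors become queryable. The query_and_check_distances helper is defined elsewhere in this test module; a hypothetical reconstruction matching the call sites above (the signature is an assumption, the real definition may differ):

    import numpy as np

    # Hypothetical reconstruction of the helper used in the tests above.
    def query_and_check_distances(index, queries, k, expected_distances, expected_ids):
        distances, ids = index.query(queries, k=k)
        assert np.array_equal(distances, np.array(expected_distances, dtype=distances.dtype))
        assert np.array_equal(ids, np.array(expected_ids, dtype=ids.dtype))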
2 changes: 1 addition & 1 deletion apis/python/test/test_ingestion.py
@@ -681,7 +681,7 @@ def test_ingestion_timetravel(tmp_path):
timestamp=20,
)

index = index.consolidate_updates()
index = index.consolidate_updates(retrain_index=True)

# We still have no results before timestamp 10.
query_and_check_equals(
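This one-line change tracks the behavior exercised in test_index.py above: consolidate_updates() alone no longer retrains, so for IVF_PQ the freshly consolidated vectors may come back as sentinel "not found" values until the index is retrained. A usage sketch, inferred from the tests above rather than stated anywhere in this diff; `index` is an already-ingested tiledb.vector_search index object:

    # Assumed semantics after this commit, inferred from the tests above.
    index = index.consolidate_updates()  # default: no retrain; IVF_PQ updates
                                         # may not be queryable yet
    index = index.consolidate_updates(retrain_index=True)  # retrain so the
                                                           # updates are found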