Skip to content

Commit

Permalink
permanent link
Browse files: browse the repository at this point in the history
  • Loading branch information
mlin committed Jun 9, 2024
1 parent 59f2ca2 commit f560172
Showing 1 changed file with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import tiledb.vector_search as vs

from .._experiment import _get_experiment_name
from .._util import _uri_join
from ._embedding import get_embedding_metadata_by_name

# FIXME: use permanent URI
CENSUS_EMBEDDINGS_INDEX_BASE_URI = "s3://mlin-census-scratch/census_embeddings_indexer/out/3540236/out/indexes/"
CENSUS_EMBEDDINGS_INDEX_URI_FSTR = (
"s3://cellxgene-contrib-public/contrib/cell-census/soma/{census_version}/indexes/{embedding_id}"
)
CENSUS_EMBEDDINGS_INDEX_REGION = "us-west-2"


Expand All @@ -39,6 +39,7 @@ def find_nearest_embeddings(
embedding_type: str = "obs_embedding",
k: int = 10,
nprobe: int = 100,
memory_GiB: int = 4,
**kwargs: Dict[str, Any],
) -> NearestEmbeddings:
"""Search Census embeddings for the nearest neighbors of query embeddings.
Expand All @@ -54,6 +55,8 @@ def find_nearest_embeddings(
nprobe:
Sensitivity parameter; defaults to 100 (roughly N^0.25 where N is the number of Census
cells) for a thorough search. Decrease for faster but less accurate search.
memory_GiB:
Memory budget for the search index, in gibibytes; defaults to 4 GiB.
"""
assert (
embedding_type == "obs_embedding"
Expand All @@ -71,10 +74,10 @@ def find_nearest_embeddings(
)

# formulate index URI and run query
index_uri = _uri_join(CENSUS_EMBEDDINGS_INDEX_BASE_URI, f"{census_version}/{emb_metadata['id']}")
# TODO: parameterize memory_budget and get advice how to set it
index_uri = CENSUS_EMBEDDINGS_INDEX_URI_FSTR.format(census_version=census_version, embedding_id=emb_metadata["id"])
memory_vectors = memory_GiB * (2**30) // (4 * emb_metadata["n_features"]) # number of float32 vectors
index = vs.ivf_flat_index.IVFFlatIndex(
uri=index_uri, config={"vfs.s3.region": CENSUS_EMBEDDINGS_INDEX_REGION}, memory_budget=4 * (2**20)
uri=index_uri, config={"vfs.s3.region": CENSUS_EMBEDDINGS_INDEX_REGION}, memory_budget=memory_vectors
)
distances, soma_joinids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs)

Expand Down

0 comments on commit f560172

Please sign in to comment.