-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
122 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
127 changes: 94 additions & 33 deletions
127
api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,147 @@ | ||
"""Nearest-neighbor search based on vector index of Census embeddings.""" | ||
|
||
from typing import Any, Dict, NamedTuple | ||
from contextlib import ExitStack | ||
from typing import Any, Dict, List, NamedTuple, Optional, Sequence | ||
|
||
import anndata as ad | ||
import numpy as np | ||
import numpy.typing as npt | ||
import pandas as pd | ||
import tiledb.vector_search as vs | ||
import tiledbsoma as soma | ||
|
||
from .._experiment import _get_experiment_name | ||
from ._embedding import get_embedding_metadata_by_name | ||
from .._open import open_soma | ||
|
||
CENSUS_EMBEDDINGS_INDEX_URI_FSTR = ( | ||
"s3://cellxgene-contrib-public/contrib/cell-census/soma/{census_version}/indexes/{embedding_id}" | ||
) | ||
CENSUS_EMBEDDINGS_INDEX_REGION = "us-west-2" | ||
|
||
|
||
class NearestEmbeddings(NamedTuple): | ||
"""Results of a nearest-neighbor search for Census embeddings.""" | ||
class NeighborObs(NamedTuple): | ||
"""Results of nearest-neighbor search for Census obs embeddings.""" | ||
|
||
distances: npt.NDArray[np.float32] | ||
""" | ||
Distances to the nearest neighbors for each query embedding (q by k, where q is the number of | ||
query embeddings and k is the desired number of neighbors). | ||
Distances to the nearest neighbors for each query obs embedding (q by k, where q is the number | ||
of query embeddings and k is the desired number of neighbors). The distance metric is | ||
implementation-dependent. | ||
""" | ||
|
||
soma_joinids: npt.NDArray[np.int64] | ||
neighbor_ids: npt.NDArray[np.int64] | ||
""" | ||
IDs of the nearest neighbors for each query embedding (q by k). | ||
obs soma_joinid's of the nearest neighbors for each query embedding (q by k). | ||
""" | ||
|
||
query_ids: npt.NDArray[np.int64] | ||
""" | ||
obs soma_joinid's of the original query cells (q by 1). | ||
""" | ||
|
||
def find_nearest_embeddings( | ||
|
||
def find_nearest_obs( | ||
embedding_metadata: Dict[str, Any], | ||
query: ad.AnnData, | ||
embedding_name: str, | ||
organism: str, | ||
census_version: str, | ||
embedding_type: str = "obs_embedding", | ||
k: int = 10, | ||
nprobe: int = 100, | ||
memory_GiB: int = 4, | ||
**kwargs: Dict[str, Any], | ||
) -> NearestEmbeddings: | ||
"""Search Census embeddings for the nearest neighbors of query embeddings. | ||
) -> NeighborObs: | ||
"""Search Census for similar obs (cells) based on nearest neighbors in embedding space. | ||
Args: | ||
embedding_metadata: | ||
Information about the embedding to search, as found by | ||
:func:`get_embedding_metadata_by_name`. | ||
query: | ||
AnnData object with an obms layer containing the query embeddings. | ||
embedding_name, organism, census_version, embedding_type: | ||
Identify the embedding to search, as given to :func:`get_embedding_metadata_by_name` | ||
or :func:`get_anndata`. The query obsm layer must match ``embedding_name``. | ||
AnnData object with an obsm layer embedding the query cells. The obsm layer name | ||
matches ``embedding_metadata["embedding_name"]`` (e.g. scvi, geneformer). | ||
k: | ||
Number of nearest neighbors to return for each query embedding. | ||
Number of nearest neighbors to return for each query obs. | ||
nprobe: | ||
Sensitivity parameter; defaults to 100 (roughly N^0.25 where N is the number of Census | ||
cells) for a thorough search. Decrease for faster but less accurate search. | ||
memory_GiB: | ||
Memory budget for the search index, in gibibytes; defaults to 4 GiB. | ||
""" | ||
assert ( | ||
embedding_type == "obs_embedding" | ||
), "cellxgene_census.experimental.find_nearest_embeddings: only embedding_type='obs_embedding' is currently supported" | ||
# resolve embedding_name | ||
experiment_name = _get_experiment_name(organism) | ||
emb_metadata = get_embedding_metadata_by_name(embedding_name, experiment_name, census_version, embedding_type) | ||
embedding_name = embedding_metadata["embedding_name"] | ||
n_features = embedding_metadata["n_features"] | ||
|
||
# validate query (expected obsm layer exists with the expected dimensionality) | ||
if embedding_name not in query.obsm: | ||
raise ValueError(f"Query does not have the expected layer {embedding_name}") | ||
if query.obsm[embedding_name].shape[1] != emb_metadata["n_features"]: | ||
if query.obsm[embedding_name].shape[1] != n_features: | ||
raise ValueError( | ||
f"Query embedding {embedding_name} has {query.obsm[embedding_name].shape[1]} features, expected {emb_metadata['n_features']}" | ||
f"Query embedding {embedding_name} has {query.obsm[embedding_name].shape[1]} features, expected {n_features}" | ||
) | ||
|
||
# formulate index URI and run query | ||
index_uri = CENSUS_EMBEDDINGS_INDEX_URI_FSTR.format(census_version=census_version, embedding_id=emb_metadata["id"]) | ||
memory_vectors = memory_GiB * (2**30) // (4 * emb_metadata["n_features"]) # number of float32 vectors | ||
index_uri = CENSUS_EMBEDDINGS_INDEX_URI_FSTR.format( | ||
census_version=embedding_metadata["census_version"], embedding_id=embedding_metadata["id"] | ||
) | ||
memory_vectors = memory_GiB * (2**30) // (4 * n_features) # number of float32 vectors | ||
index = vs.ivf_flat_index.IVFFlatIndex( | ||
uri=index_uri, | ||
config={"vfs.s3.region": CENSUS_EMBEDDINGS_INDEX_REGION, "vfs.s3.no_sign_request": "true"}, | ||
memory_budget=memory_vectors, | ||
) | ||
distances, soma_joinids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs) | ||
distances, neighbor_ids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs) | ||
|
||
return NeighborObs(distances=distances, neighbor_ids=neighbor_ids, query_ids=np.array(query.obs.soma_joinid.values)) | ||
|
||
|
||
def predict_obs_metadata( | ||
embedding_metadata: Dict[str, Any], | ||
neighbors: NeighborObs, | ||
column_names: Sequence[str], | ||
measurement: Optional[soma.Measurement] = None, | ||
) -> pd.DataFrame: | ||
"""Predict obs metadata attributes for the query based on the embedding nearest neighbors. | ||
Args: | ||
embedding_metadata: | ||
Information about the embedding to search, as found by | ||
:func:`get_embedding_metadata_by_name`. | ||
neighbors: | ||
Results of a ``find_nearest_obs`` search. | ||
column_names: | ||
Desired obs metadata column names. | ||
measurement: | ||
Open handle for the relevant SOMAMeasurement, if available (otherwise, will be opened | ||
internally). e.g. ``census["census_data"]["homo_sapiens"]["RNA"]`` with the relevant | ||
Census version open. | ||
Returns: | ||
Pandas DataFrame with the desired column predictions, indexed by query ``soma_joinid`` . | ||
Additionally, for each predicted column ``col``, an additional column ``col_confidence`` | ||
with a confidence score between 0 and 1. | ||
""" | ||
with ExitStack() as cleanup: | ||
if measurement is None: | ||
# open Census transiently | ||
census = cleanup.enter_context(open_soma(census_version=embedding_metadata["census_version"])) | ||
measurement = census["census_data"][embedding_metadata["experiment_name"]][ | ||
embedding_metadata["measurement_name"] | ||
] | ||
|
||
# fetch the desired obs metadata for all of the found neighbors | ||
neighbor_obs = ( | ||
measurement.obs.read(coords=neighbors.neighbor_ids.flatten(), column_names=column_names) | ||
.concat() | ||
.to_pandas() | ||
) | ||
|
||
# step through query cells to generate prediction for each column as the plurality value | ||
# found among its neighbors, with a confidence score based on the simple fraction (for now) | ||
# TODO: something more intelligent for numeric columns! also use distances, etc. | ||
out: Dict[str, List[Any]] = {col: [] for col in column_names} | ||
out["soma_joinid"] = [] | ||
for i, query_id in enumerate(neighbors.query_ids): | ||
out["soma_joinid"].append(query_id) | ||
neighbors_i = neighbor_obs[neighbor_obs.index.isin(neighbors.neighbor_ids[i])] | ||
for col in column_names: | ||
col_value_counts = neighbors_i[col].value_counts(normalize=True) | ||
out[col].append(col_value_counts.idxmax()) | ||
out[col + "_confidence"].append(col_value_counts.max()) | ||
|
||
return NearestEmbeddings(distances=distances, soma_joinids=soma_joinids) | ||
return pd.DataFrame(out).set_index("soma_joinid") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters