From 0f0b90d1019bb500fd591ac556038580b6b977c7 Mon Sep 17 00:00:00 2001 From: Mike Lin Date: Sat, 8 Jun 2024 22:31:13 -1000 Subject: [PATCH] refactor & add metadata predictor --- .../cellxgene_census/experimental/__init__.py | 5 +- .../experimental/_embedding_search.py | 127 +++++++++++++----- .../experimental/test_embeddings_search.py | 52 ++++--- 3 files changed, 122 insertions(+), 62 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/__init__.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/__init__.py index ce6c3384f..93f0c938c 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/__init__.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/__init__.py @@ -7,7 +7,7 @@ get_embedding_metadata, get_embedding_metadata_by_name, ) -from ._embedding_search import find_nearest_embeddings +from ._embedding_search import find_nearest_obs, predict_obs_metadata __all__ = [ "get_embedding", @@ -15,5 +15,6 @@ "get_embedding_metadata_by_name", "get_all_available_embeddings", "get_all_census_versions_with_embedding", - "find_nearest_embeddings", + "find_nearest_obs", + "predict_obs_metadata", ] diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py index 7ae6aed29..09911022e 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py @@ -1,14 +1,16 @@ """Nearest-neighbor search based on vector index of Census embeddings.""" -from typing import Any, Dict, NamedTuple +from contextlib import ExitStack +from typing import Any, Dict, List, NamedTuple, Optional, Sequence import anndata as ad import numpy as np import numpy.typing as npt +import pandas as pd import tiledb.vector_search as vs +import tiledbsoma as soma -from .._experiment import _get_experiment_name -from ._embedding import get_embedding_metadata_by_name +from .._open import open_soma CENSUS_EMBEDDINGS_INDEX_URI_FSTR = ( "s3://cellxgene-contrib-public/contrib/cell-census/soma/{census_version}/indexes/{embedding_id}" @@ -16,71 +18,130 @@ CENSUS_EMBEDDINGS_INDEX_REGION = "us-west-2" -class NearestEmbeddings(NamedTuple): - """Results of a nearest-neighbor search for Census embeddings.""" +class NeighborObs(NamedTuple): + """Results of nearest-neighbor search for Census obs embeddings.""" distances: npt.NDArray[np.float32] """ - Distances to the nearest neighbors for each query embedding (q by k, where q is the number of - query embeddings and k is the desired number of neighbors). + Distances to the nearest neighbors for each query obs embedding (q by k, where q is the number + of query embeddings and k is the desired number of neighbors). The distance metric is + implementation-dependent. """ - soma_joinids: npt.NDArray[np.int64] + neighbor_ids: npt.NDArray[np.int64] """ - IDs of the nearest neighbors for each query embedding (q by k). + obs soma_joinid's of the nearest neighbors for each query embedding (q by k). """ + query_ids: npt.NDArray[np.int64] + """ + obs soma_joinid's of the original query cells (q by 1). + """ -def find_nearest_embeddings( + +def find_nearest_obs( + embedding_metadata: Dict[str, Any], query: ad.AnnData, - embedding_name: str, - organism: str, - census_version: str, - embedding_type: str = "obs_embedding", k: int = 10, nprobe: int = 100, memory_GiB: int = 4, **kwargs: Dict[str, Any], -) -> NearestEmbeddings: - """Search Census embeddings for the nearest neighbors of query embeddings. +) -> NeighborObs: + """Search Census for similar obs (cells) based on nearest neighbors in embedding space. Args: + embedding_metadata: + Information about the embedding to search, as found by + :func:`get_embedding_metadata_by_name`. query: - AnnData object with an obms layer containing the query embeddings. - embedding_name, organism, census_version, embedding_type: - Identify the embedding to search, as given to :func:`get_embedding_metadata_by_name` - or :func:`get_anndata`. The query obsm layer must match ``embedding_name``. + AnnData object with an obsm layer embedding the query cells. The obsm layer name + matches ``embedding_metadata["embedding_name"]`` (e.g. scvi, geneformer). k: - Number of nearest neighbors to return for each query embedding. + Number of nearest neighbors to return for each query obs. nprobe: Sensitivity parameter; defaults to 100 (roughly N^0.25 where N is the number of Census cells) for a thorough search. Decrease for faster but less accurate search. memory_GiB: Memory budget for the search index, in gibibytes; defaults to 4 GiB. """ - assert ( - embedding_type == "obs_embedding" - ), "cellxgene_census.experimental.find_nearest_embeddings: only embedding_type='obs_embedding' is currently supported" - # resolve embedding_name - experiment_name = _get_experiment_name(organism) - emb_metadata = get_embedding_metadata_by_name(embedding_name, experiment_name, census_version, embedding_type) + embedding_name = embedding_metadata["embedding_name"] + n_features = embedding_metadata["n_features"] # validate query (expected obsm layer exists with the expected dimensionality) if embedding_name not in query.obsm: raise ValueError(f"Query does not have the expected layer {embedding_name}") - if query.obsm[embedding_name].shape[1] != emb_metadata["n_features"]: + if query.obsm[embedding_name].shape[1] != n_features: raise ValueError( - f"Query embedding {embedding_name} has {query.obsm[embedding_name].shape[1]} features, expected {emb_metadata['n_features']}" + f"Query embedding {embedding_name} has {query.obsm[embedding_name].shape[1]} features, expected {n_features}" ) # formulate index URI and run query - index_uri = CENSUS_EMBEDDINGS_INDEX_URI_FSTR.format(census_version=census_version, embedding_id=emb_metadata["id"]) - memory_vectors = memory_GiB * (2**30) // (4 * emb_metadata["n_features"]) # number of float32 vectors + index_uri = CENSUS_EMBEDDINGS_INDEX_URI_FSTR.format( + census_version=embedding_metadata["census_version"], embedding_id=embedding_metadata["id"] + ) + memory_vectors = memory_GiB * (2**30) // (4 * n_features) # number of float32 vectors index = vs.ivf_flat_index.IVFFlatIndex( uri=index_uri, config={"vfs.s3.region": CENSUS_EMBEDDINGS_INDEX_REGION, "vfs.s3.no_sign_request": "true"}, memory_budget=memory_vectors, ) - distances, soma_joinids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs) + distances, neighbor_ids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs) + + return NeighborObs(distances=distances, neighbor_ids=neighbor_ids, query_ids=np.array(query.obs.soma_joinid.values)) + + +def predict_obs_metadata( + embedding_metadata: Dict[str, Any], + neighbors: NeighborObs, + column_names: Sequence[str], + measurement: Optional[soma.Measurement] = None, +) -> pd.DataFrame: + """Predict obs metadata attributes for the query based on the embedding nearest neighbors. + + Args: + embedding_metadata: + Information about the embedding to search, as found by + :func:`get_embedding_metadata_by_name`. + neighbors: + Results of a ``find_nearest_obs`` search. + column_names: + Desired obs metadata column names. + measurement: + Open handle for the relevant SOMAMeasurement, if available (otherwise, will be opened + internally). e.g. ``census["census_data"]["homo_sapiens"]["RNA"]`` with the relevant + Census version open. + + Returns: + Pandas DataFrame with the desired column predictions, indexed by query ``soma_joinid`` . + Additionally, for each predicted column ``col``, an additional column ``col_confidence`` + with a confidence score between 0 and 1. + """ + with ExitStack() as cleanup: + if measurement is None: + # open Census transiently + census = cleanup.enter_context(open_soma(census_version=embedding_metadata["census_version"])) + measurement = census["census_data"][embedding_metadata["experiment_name"]][ + embedding_metadata["measurement_name"] + ] + + # fetch the desired obs metadata for all of the found neighbors + neighbor_obs = ( + measurement.obs.read(coords=neighbors.neighbor_ids.flatten(), column_names=column_names) + .concat() + .to_pandas() + ) + + # step through query cells to generate prediction for each column as the plurality value + # found among its neighbors, with a confidence score based on the simple fraction (for now) + # TODO: something more intelligent for numeric columns! also use distances, etc. + out: Dict[str, List[Any]] = {col: [] for col in column_names} + out["soma_joinid"] = [] + for i, query_id in enumerate(neighbors.query_ids): + out["soma_joinid"].append(query_id) + neighbors_i = neighbor_obs[neighbor_obs.index.isin(neighbors.neighbor_ids[i])] + for col in column_names: + col_value_counts = neighbors_i[col].value_counts(normalize=True) + out[col].append(col_value_counts.idxmax()) + out[col + "_confidence"].append(col_value_counts.max()) - return NearestEmbeddings(distances=distances, soma_joinids=soma_joinids) + return pd.DataFrame(out).set_index("soma_joinid") diff --git a/api/python/cellxgene_census/tests/experimental/test_embeddings_search.py b/api/python/cellxgene_census/tests/experimental/test_embeddings_search.py index 6d8a2a204..6984348dc 100644 --- a/api/python/cellxgene_census/tests/experimental/test_embeddings_search.py +++ b/api/python/cellxgene_census/tests/experimental/test_embeddings_search.py @@ -4,19 +4,20 @@ import anndata as ad import numpy as np import pytest +import tiledb import cellxgene_census -from cellxgene_census.experimental import find_nearest_embeddings +from cellxgene_census.experimental import find_nearest_obs, get_embedding_metadata_by_name @pytest.mark.experimental @pytest.mark.live_corpus -def test_embeddings_search(true_neighbors: Dict[str, Any], query_anndata: ad.AnnData) -> None: - rslt = find_nearest_embeddings( +def test_embeddings_search( + emb_metadata: Dict[str, Any], true_neighbors: Dict[str, Any], query_anndata: ad.AnnData +) -> None: + rslt = find_nearest_obs( + emb_metadata, query_anndata, - TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME, - TRUE_NEAREST_NEIGHBORS_ORGANISM, - TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, k=TRUE_NEAREST_NEIGHBORS_K, nprobe=25, ) @@ -45,38 +46,35 @@ def test_embeddings_search(true_neighbors: Dict[str, Any], query_anndata: ad.Ann @pytest.mark.experimental @pytest.mark.live_corpus -def test_embeddings_search_errors(query_anndata: ad.AnnData) -> None: - # unknown embedding - with pytest.raises(ValueError, match="No embeddings found"): - find_nearest_embeddings( - query_anndata, - "BOGUS123", - TRUE_NEAREST_NEIGHBORS_ORGANISM, - TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, - ) +def test_embeddings_search_errors(emb_metadata: Dict[str, Any], query_anndata: ad.AnnData) -> None: + # bogus embedding metadata + emb_metadata2 = emb_metadata.copy() + emb_metadata2["embedding_id"] = "BOGUS123" + with pytest.raises(tiledb.TileDBError): + find_nearest_obs(emb_metadata2, query_anndata) # query anndata missing the embedding layer bogus_ad = query_anndata.copy() bogus_ad.obsm.pop(TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME) with pytest.raises(ValueError, match="Query does not have"): - find_nearest_embeddings( - bogus_ad, - TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME, - TRUE_NEAREST_NEIGHBORS_ORGANISM, - TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, - ) + find_nearest_obs(emb_metadata, bogus_ad) # embedding layer has wrong number of features bogus_ad = query_anndata.copy() bogus_ad.obsm[TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME] = np.zeros((len(bogus_ad), 42)) with pytest.raises(ValueError, match="features, expected"): - find_nearest_embeddings( - bogus_ad, - TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME, - TRUE_NEAREST_NEIGHBORS_ORGANISM, - TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, - ) + find_nearest_obs(emb_metadata, bogus_ad) return +@pytest.fixture(scope="module") +def emb_metadata() -> Dict[str, Any]: + return get_embedding_metadata_by_name( + TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME, + TRUE_NEAREST_NEIGHBORS_ORGANISM, + TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, + "obs_embedding", + ) + + @pytest.fixture(scope="module") def true_neighbors() -> Dict[int, List[Dict[str, Any]]]: ans = {}