Skip to content

Commit

Permalink
refactor & add metadata predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Jun 9, 2024
1 parent d4b3028 commit 0f0b90d
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
get_embedding_metadata,
get_embedding_metadata_by_name,
)
from ._embedding_search import find_nearest_embeddings
from ._embedding_search import find_nearest_obs, predict_obs_metadata

__all__ = [
"get_embedding",
"get_embedding_metadata",
"get_embedding_metadata_by_name",
"get_all_available_embeddings",
"get_all_census_versions_with_embedding",
"find_nearest_embeddings",
"find_nearest_obs",
"predict_obs_metadata",
]
Original file line number Diff line number Diff line change
@@ -1,86 +1,147 @@
"""Nearest-neighbor search based on vector index of Census embeddings."""

from typing import Any, Dict, NamedTuple
from contextlib import ExitStack
from typing import Any, Dict, List, NamedTuple, Optional, Sequence

import anndata as ad
import numpy as np
import numpy.typing as npt
import pandas as pd
import tiledb.vector_search as vs
import tiledbsoma as soma

from .._experiment import _get_experiment_name
from ._embedding import get_embedding_metadata_by_name
from .._open import open_soma

CENSUS_EMBEDDINGS_INDEX_URI_FSTR = (
"s3://cellxgene-contrib-public/contrib/cell-census/soma/{census_version}/indexes/{embedding_id}"
)
CENSUS_EMBEDDINGS_INDEX_REGION = "us-west-2"


class NearestEmbeddings(NamedTuple):
"""Results of a nearest-neighbor search for Census embeddings."""
class NeighborObs(NamedTuple):
"""Results of nearest-neighbor search for Census obs embeddings."""

distances: npt.NDArray[np.float32]
"""
Distances to the nearest neighbors for each query embedding (q by k, where q is the number of
query embeddings and k is the desired number of neighbors).
Distances to the nearest neighbors for each query obs embedding (q by k, where q is the number
of query embeddings and k is the desired number of neighbors). The distance metric is
implementation-dependent.
"""

soma_joinids: npt.NDArray[np.int64]
neighbor_ids: npt.NDArray[np.int64]
"""
IDs of the nearest neighbors for each query embedding (q by k).
obs soma_joinid's of the nearest neighbors for each query embedding (q by k).
"""

query_ids: npt.NDArray[np.int64]
"""
obs soma_joinid's of the original query cells (q by 1).
"""

def find_nearest_embeddings(

def find_nearest_obs(
embedding_metadata: Dict[str, Any],
query: ad.AnnData,
embedding_name: str,
organism: str,
census_version: str,
embedding_type: str = "obs_embedding",
k: int = 10,
nprobe: int = 100,
memory_GiB: int = 4,
**kwargs: Dict[str, Any],
) -> NearestEmbeddings:
"""Search Census embeddings for the nearest neighbors of query embeddings.
) -> NeighborObs:
"""Search Census for similar obs (cells) based on nearest neighbors in embedding space.
Args:
embedding_metadata:
Information about the embedding to search, as found by
:func:`get_embedding_metadata_by_name`.
query:
AnnData object with an obms layer containing the query embeddings.
embedding_name, organism, census_version, embedding_type:
Identify the embedding to search, as given to :func:`get_embedding_metadata_by_name`
or :func:`get_anndata`. The query obsm layer must match ``embedding_name``.
AnnData object with an obsm layer embedding the query cells. The obsm layer name
matches ``embedding_metadata["embedding_name"]`` (e.g. scvi, geneformer).
k:
Number of nearest neighbors to return for each query embedding.
Number of nearest neighbors to return for each query obs.
nprobe:
Sensitivity parameter; defaults to 100 (roughly N^0.25 where N is the number of Census
cells) for a thorough search. Decrease for faster but less accurate search.
memory_GiB:
Memory budget for the search index, in gibibytes; defaults to 4 GiB.
"""
assert (
embedding_type == "obs_embedding"
), "cellxgene_census.experimental.find_nearest_embeddings: only embedding_type='obs_embedding' is currently supported"
# resolve embedding_name
experiment_name = _get_experiment_name(organism)
emb_metadata = get_embedding_metadata_by_name(embedding_name, experiment_name, census_version, embedding_type)
embedding_name = embedding_metadata["embedding_name"]
n_features = embedding_metadata["n_features"]

# validate query (expected obsm layer exists with the expected dimensionality)
if embedding_name not in query.obsm:
raise ValueError(f"Query does not have the expected layer {embedding_name}")
if query.obsm[embedding_name].shape[1] != emb_metadata["n_features"]:
if query.obsm[embedding_name].shape[1] != n_features:
raise ValueError(
f"Query embedding {embedding_name} has {query.obsm[embedding_name].shape[1]} features, expected {emb_metadata['n_features']}"
f"Query embedding {embedding_name} has {query.obsm[embedding_name].shape[1]} features, expected {n_features}"
)

# formulate index URI and run query
index_uri = CENSUS_EMBEDDINGS_INDEX_URI_FSTR.format(census_version=census_version, embedding_id=emb_metadata["id"])
memory_vectors = memory_GiB * (2**30) // (4 * emb_metadata["n_features"]) # number of float32 vectors
index_uri = CENSUS_EMBEDDINGS_INDEX_URI_FSTR.format(
census_version=embedding_metadata["census_version"], embedding_id=embedding_metadata["id"]
)
memory_vectors = memory_GiB * (2**30) // (4 * n_features) # number of float32 vectors
index = vs.ivf_flat_index.IVFFlatIndex(
uri=index_uri,
config={"vfs.s3.region": CENSUS_EMBEDDINGS_INDEX_REGION, "vfs.s3.no_sign_request": "true"},
memory_budget=memory_vectors,
)
distances, soma_joinids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs)
distances, neighbor_ids = index.query(query.obsm[embedding_name], k=k, nprobe=nprobe, **kwargs)

return NeighborObs(distances=distances, neighbor_ids=neighbor_ids, query_ids=np.array(query.obs.soma_joinid.values))


def predict_obs_metadata(
embedding_metadata: Dict[str, Any],
neighbors: NeighborObs,
column_names: Sequence[str],
measurement: Optional[soma.Measurement] = None,
) -> pd.DataFrame:
"""Predict obs metadata attributes for the query based on the embedding nearest neighbors.
Args:
embedding_metadata:
Information about the embedding to search, as found by
:func:`get_embedding_metadata_by_name`.
neighbors:
Results of a ``find_nearest_obs`` search.
column_names:
Desired obs metadata column names.
measurement:
Open handle for the relevant SOMAMeasurement, if available (otherwise, will be opened
internally). e.g. ``census["census_data"]["homo_sapiens"]["RNA"]`` with the relevant
Census version open.
Returns:
Pandas DataFrame with the desired column predictions, indexed by query ``soma_joinid`` .
Additionally, for each predicted column ``col``, an additional column ``col_confidence``
with a confidence score between 0 and 1.
"""
with ExitStack() as cleanup:
if measurement is None:
# open Census transiently
census = cleanup.enter_context(open_soma(census_version=embedding_metadata["census_version"]))
measurement = census["census_data"][embedding_metadata["experiment_name"]][
embedding_metadata["measurement_name"]
]

# fetch the desired obs metadata for all of the found neighbors
neighbor_obs = (
measurement.obs.read(coords=neighbors.neighbor_ids.flatten(), column_names=column_names)
.concat()
.to_pandas()
)

# step through query cells to generate prediction for each column as the plurality value
# found among its neighbors, with a confidence score based on the simple fraction (for now)
# TODO: something more intelligent for numeric columns! also use distances, etc.
out: Dict[str, List[Any]] = {col: [] for col in column_names}
out["soma_joinid"] = []
for i, query_id in enumerate(neighbors.query_ids):
out["soma_joinid"].append(query_id)
neighbors_i = neighbor_obs[neighbor_obs.index.isin(neighbors.neighbor_ids[i])]
for col in column_names:
col_value_counts = neighbors_i[col].value_counts(normalize=True)
out[col].append(col_value_counts.idxmax())
out[col + "_confidence"].append(col_value_counts.max())

return NearestEmbeddings(distances=distances, soma_joinids=soma_joinids)
return pd.DataFrame(out).set_index("soma_joinid")
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
import anndata as ad
import numpy as np
import pytest
import tiledb

import cellxgene_census
from cellxgene_census.experimental import find_nearest_embeddings
from cellxgene_census.experimental import find_nearest_obs, get_embedding_metadata_by_name


@pytest.mark.experimental
@pytest.mark.live_corpus
def test_embeddings_search(true_neighbors: Dict[str, Any], query_anndata: ad.AnnData) -> None:
rslt = find_nearest_embeddings(
def test_embeddings_search(
emb_metadata: Dict[str, Any], true_neighbors: Dict[str, Any], query_anndata: ad.AnnData
) -> None:
rslt = find_nearest_obs(
emb_metadata,
query_anndata,
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
k=TRUE_NEAREST_NEIGHBORS_K,
nprobe=25,
)
Expand Down Expand Up @@ -45,38 +46,35 @@ def test_embeddings_search(true_neighbors: Dict[str, Any], query_anndata: ad.Ann

@pytest.mark.experimental
@pytest.mark.live_corpus
def test_embeddings_search_errors(query_anndata: ad.AnnData) -> None:
# unknown embedding
with pytest.raises(ValueError, match="No embeddings found"):
find_nearest_embeddings(
query_anndata,
"BOGUS123",
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
)
def test_embeddings_search_errors(emb_metadata: Dict[str, Any], query_anndata: ad.AnnData) -> None:
# bogus embedding metadata
emb_metadata2 = emb_metadata.copy()
emb_metadata2["embedding_id"] = "BOGUS123"
with pytest.raises(tiledb.TileDBError):
find_nearest_obs(emb_metadata2, query_anndata)
# query anndata missing the embedding layer
bogus_ad = query_anndata.copy()
bogus_ad.obsm.pop(TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME)
with pytest.raises(ValueError, match="Query does not have"):
find_nearest_embeddings(
bogus_ad,
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
)
find_nearest_obs(emb_metadata, bogus_ad)
# embedding layer has wrong number of features
bogus_ad = query_anndata.copy()
bogus_ad.obsm[TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME] = np.zeros((len(bogus_ad), 42))
with pytest.raises(ValueError, match="features, expected"):
find_nearest_embeddings(
bogus_ad,
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
)
find_nearest_obs(emb_metadata, bogus_ad)
return


@pytest.fixture(scope="module")
def emb_metadata() -> Dict[str, Any]:
return get_embedding_metadata_by_name(
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
"obs_embedding",
)


@pytest.fixture(scope="module")
def true_neighbors() -> Dict[int, List[Dict[str, Any]]]:
ans = {}
Expand Down

0 comments on commit 0f0b90d

Please sign in to comment.