[python] similarity search API: optimize predict_obs_metadata (#1257)
* squash for PR

* use DEFAULT_TILEDB_CONFIGURATION

* workaround

* workaround

* fix

* resolve indexes through JSONs

* lint

* API refactoring

* Update api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py

Co-authored-by: Isaac Virshup <ivirshup@gmail.com>

* fixups

* Update api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py

Co-authored-by: Isaac Virshup <ivirshup@gmail.com>

---------

Co-authored-by: Isaac Virshup <ivirshup@gmail.com>
mlin and ivirshup authored Sep 9, 2024
1 parent fc7aefe commit 3070d6a
Showing 1 changed file with 24 additions and 7 deletions.
@@ -10,6 +10,7 @@
 import pandas as pd
 import tiledb.vector_search as vs
 import tiledbsoma as soma
+from scipy import sparse
 
 from .._experiment import _get_experiment_name
 from .._open import DEFAULT_TILEDB_CONFIGURATION, open_soma
@@ -157,12 +158,28 @@ def predict_obs_metadata(
     # step through query cells to generate prediction for each column as the plurality value
     # found among its neighbors, with a confidence score based on the simple fraction (for now)
     # TODO: something more intelligent for numeric columns! also use distances, etc.
-    out: dict[str, list[Any]] = {}
-    for i in range(neighbors.neighbor_ids.shape[0]):
-        neighbors_i = neighbor_obs.loc[neighbors.neighbor_ids[i]]
-        for col in column_names:
-            col_value_counts = neighbors_i[col].value_counts(normalize=True)
-            out.setdefault(col, []).append(col_value_counts.idxmax())
-            out.setdefault(col + "_confidence", []).append(col_value_counts.max())
+    max_joinid = neighbor_obs.index.max()
+    out: dict[str, pd.Series[Any]] = {}
+    indices = np.broadcast_to(np.arange(neighbors.neighbor_ids.shape[0]), (10, neighbors.neighbor_ids.shape[0])).T
+    g = sparse.csr_matrix(
+        (
+            np.broadcast_to(1, neighbors.neighbor_ids.shape[0] * 10),
+            (
+                indices.flatten(),
+                neighbors.neighbor_ids.astype(np.int64).flatten(),
+            ),
+        ),
+        shape=(neighbors.neighbor_ids.shape[0], max_joinid + 1),
+    )
+    for col in column_names:
+        col_categorical = neighbor_obs[col].astype("category")
+        joinid2category = sparse.coo_matrix(
+            (np.broadcast_to(1, len(neighbor_obs)), (neighbor_obs.index, col_categorical.cat.codes)),
+            shape=(max_joinid + 1, len(col_categorical.cat.categories)),
+        )
+        counts = g @ joinid2category
+        rel_counts = counts / counts.sum(axis=1)
+        out[col] = col_categorical.cat.categories[rel_counts.argmax(axis=1).A.flatten()].astype(object)
+        out[f"{col}_confidence"] = rel_counts.max(axis=1).toarray().flatten()
 
     return pd.DataFrame.from_dict(out)
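
To make the change concrete, here is a minimal self-contained sketch of the sparse-matrix voting trick the new code uses: build a queries-by-joinids indicator matrix of neighbor occurrences, multiply by a joinids-by-categories indicator matrix, and read each query's plurality label and its fraction off the resulting row. The toy `neighbor_ids` and `neighbor_obs` values are invented for illustration; the real function derives them from a Census embedding search.

import numpy as np
import pandas as pd
from scipy import sparse

# Toy inputs (assumed for illustration): 3 query cells with k=3 neighbors each,
# identified by obs joinids, plus metadata for the 3 distinct neighbor cells.
neighbor_ids = np.array([[0, 1, 1], [1, 2, 2], [2, 2, 0]], dtype=np.int64)
neighbor_obs = pd.DataFrame({"cell_type": ["B cell", "B cell", "T cell"]}, index=[0, 1, 2])

n_queries, k = neighbor_ids.shape
max_joinid = neighbor_obs.index.max()

# g[i, j] = how many of query i's k neighbors have joinid j.
# csr_matrix sums duplicate (row, col) entries, which is what makes this a count.
g = sparse.csr_matrix(
    (np.ones(n_queries * k), (np.repeat(np.arange(n_queries), k), neighbor_ids.ravel())),
    shape=(n_queries, max_joinid + 1),
)

col_categorical = neighbor_obs["cell_type"].astype("category")
# joinid2category[j, c] = 1 iff the cell with joinid j carries category c.
joinid2category = sparse.coo_matrix(
    (np.ones(len(neighbor_obs)), (neighbor_obs.index, col_categorical.cat.codes)),
    shape=(max_joinid + 1, len(col_categorical.cat.categories)),
)

counts = g @ joinid2category  # queries x categories: neighbor votes per category
rel = counts.toarray() / np.asarray(counts.sum(axis=1))  # row-normalize to fractions
predictions = col_categorical.cat.categories[rel.argmax(axis=1)]
confidence = rel.max(axis=1)

print(pd.DataFrame({"cell_type": predictions, "cell_type_confidence": confidence}))
# query 0 -> B cell (1.00); queries 1 and 2 -> T cell (0.67)

One difference from the committed code: the diff hardcodes the neighbor count (10), whereas the sketch derives k from `neighbor_ids.shape`, so it works for any k.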
