Skip to content

Commit

Permalink
fixups
Browse files (browse the repository at this point in the history)
  • Loading branch information
mlin committed Jul 16, 2024
1 parent b35048c commit eaa8868
Showing 1 changed file with 6 additions and 11 deletions.
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,15 @@
"""Nearest-neighbor search based on vector index of Census embeddings."""

from contextlib import ExitStack
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, cast
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple, cast

import anndata as ad
import numpy as np
import numpy.typing as npt
import pandas as pd
import tiledb.vector_search as vs
import tiledbsoma as soma
from scipy import sparse

from .._experiment import _get_experiment_name
from .._open import DEFAULT_TILEDB_CONFIGURATION, open_soma
Expand Down Expand Up @@ -157,7 +158,7 @@ def predict_obs_metadata(
# found among its neighbors, with a confidence score based on the simple fraction (for now)
# TODO: something more intelligent for numeric columns! also use distances, etc.
max_joinid = neighbor_obs.index.max()
out: dict[str, pd.Series] = {}
out: dict[str, pd.Series[Any]] = {}
indices = np.broadcast_to(np.arange(neighbors.neighbor_ids.shape[0]), (10, neighbors.neighbor_ids.shape[0])).T
g = sparse.csr_matrix(
(
Expand All @@ -167,19 +168,13 @@ def predict_obs_metadata(
neighbors.neighbor_ids.astype(np.int64).flatten(),
),
),
shape=(neighbors.neighbor_ids.shape[0], max_joinid + 1)
shape=(neighbors.neighbor_ids.shape[0], max_joinid + 1),
)
for col in column_names:
col_categorical = neighbor_obs[col].astype("category")
joinid2category = sparse.coo_matrix(
(
np.broadcast_to(1, len(neighbor_obs)),
(
neighbor_obs.index,
col_categorical.cat.codes
)
),
shape=(max_joinid + 1, len(col_categorical.cat.categories))
(np.broadcast_to(1, len(neighbor_obs)), (neighbor_obs.index, col_categorical.cat.codes)),
shape=(max_joinid + 1, len(col_categorical.cat.categories)),
)
counts = g @ joinid2category
rel_counts = counts / counts.sum(axis=1)
Expand Down

0 comments on commit eaa8868

Please sign in to comment.