Skip to content

Commit

Permalink
Fix (#1277)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup authored Sep 10, 2024
1 parent fa40a6a commit 3ef4270
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,17 @@ def predict_obs_metadata(
# TODO: something more intelligent for numeric columns! also use distances, etc.
max_joinid = neighbor_obs.index.max()
out: dict[str, pd.Series[Any]] = {}
indices = np.broadcast_to(np.arange(neighbors.neighbor_ids.shape[0]), (10, neighbors.neighbor_ids.shape[0])).T
n_queries, n_neighbors = neighbors.neighbor_ids.shape
indices = np.broadcast_to(np.arange(n_queries), (n_neighbors, n_queries)).T
g = sparse.csr_matrix(
(
np.broadcast_to(1, neighbors.neighbor_ids.shape[0] * 10),
np.broadcast_to(1, n_queries * n_neighbors),
(
indices.flatten(),
neighbors.neighbor_ids.astype(np.int64).flatten(),
),
),
shape=(neighbors.neighbor_ids.shape[0], max_joinid + 1),
shape=(n_queries, max_joinid + 1),
)
for col in column_names:
col_categorical = neighbor_obs[col].astype("category")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ def test_embeddings_search(true_neighbors: dict[str, Any], query_result: Neighbo
return


@pytest.mark.experimental
@pytest.mark.live_corpus
@pytest.mark.parametrize("n_neighbors", [5, 7, 20])
def test_embedding_search_n_neighbors(query_anndata: ad.AnnData, n_neighbors: int) -> None:
columns = ["cell_type"]
result = find_nearest_obs(
TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME,
TRUE_NEAREST_NEIGHBORS_ORGANISM,
TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION,
query_anndata,
k=n_neighbors,
nprobe=25,
)

# Check that the correct number of neighbors is being returned
assert result.neighbor_ids.shape[1] == n_neighbors
# Check that this step works
_ = predict_obs_metadata(TRUE_NEAREST_NEIGHBORS_ORGANISM, TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, result, columns)


@pytest.mark.experimental
@pytest.mark.live_corpus
def test_embeddings_search_errors(query_anndata: ad.AnnData) -> None:
Expand Down

0 comments on commit 3ef4270

Please sign in to comment.