diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py index de09e2060..3fc628e78 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/_embedding_search.py @@ -160,16 +160,17 @@ def predict_obs_metadata( # TODO: something more intelligent for numeric columns! also use distances, etc. max_joinid = neighbor_obs.index.max() out: dict[str, pd.Series[Any]] = {} - indices = np.broadcast_to(np.arange(neighbors.neighbor_ids.shape[0]), (10, neighbors.neighbor_ids.shape[0])).T + n_queries, n_neighbors = neighbors.neighbor_ids.shape + indices = np.broadcast_to(np.arange(n_queries), (n_neighbors, n_queries)).T g = sparse.csr_matrix( ( - np.broadcast_to(1, neighbors.neighbor_ids.shape[0] * 10), + np.broadcast_to(1, n_queries * n_neighbors), ( indices.flatten(), neighbors.neighbor_ids.astype(np.int64).flatten(), ), ), - shape=(neighbors.neighbor_ids.shape[0], max_joinid + 1), + shape=(n_queries, max_joinid + 1), ) for col in column_names: col_categorical = neighbor_obs[col].astype("category") diff --git a/api/python/cellxgene_census/tests/experimental/test_embeddings_search.py b/api/python/cellxgene_census/tests/experimental/test_embeddings_search.py index a58a31628..9e1c3da0a 100644 --- a/api/python/cellxgene_census/tests/experimental/test_embeddings_search.py +++ b/api/python/cellxgene_census/tests/experimental/test_embeddings_search.py @@ -39,6 +39,26 @@ def test_embeddings_search(true_neighbors: dict[str, Any], query_result: Neighbo return +@pytest.mark.experimental +@pytest.mark.live_corpus +@pytest.mark.parametrize("n_neighbors", [5, 7, 20]) +def test_embedding_search_n_neighbors(query_anndata: ad.AnnData, n_neighbors: int) -> None: + columns = ["cell_type"] + result = find_nearest_obs( + TRUE_NEAREST_NEIGHBORS_EMBEDDING_NAME, + TRUE_NEAREST_NEIGHBORS_ORGANISM, + TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, + query_anndata, + k=n_neighbors, + nprobe=25, + ) + + # Check that the correct number of neighbors is being returned + assert result.neighbor_ids.shape[1] == n_neighbors + # Check that this step works + _ = predict_obs_metadata(TRUE_NEAREST_NEIGHBORS_ORGANISM, TRUE_NEAREST_NEIGHBORS_CENSUS_VERSION, result, columns) + + @pytest.mark.experimental @pytest.mark.live_corpus def test_embeddings_search_errors(query_anndata: ad.AnnData) -> None: