
[Feature Request] Ability to search and index documents with other metadata #9

Open
regstuff opened this issue Feb 2, 2023 · 14 comments
Labels
enhancement New feature or request

Comments

@regstuff

regstuff commented Feb 2, 2023

Hi,
Nice choice with War Pigs in the example. :)

I've been looking for a pure-Python search engine ever since Whoosh stopped being actively developed.

I realize this library is just getting started, but I was wondering whether it's possible to add the ability to search and filter by metadata as well.

For example

collection = [
  {"id": "doc_1", "text": "Generals gathered in their masses", "album": "War pigs"},
  {"id": "doc_2", "text": "Finished with my woman", "album": "Paranoid"}
]

I might want to search for all lines where the album contains the word "pigs", for example.
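
Something along these lines, purely as a sketch of the desired usage (the filters argument is hypothetical; nothing like it exists in the library yet):

from retriv import SearchEngine

se = SearchEngine("new-index")
se.index(collection)
# hypothetical: only score documents whose "album" field contains "pigs"
results = se.search("masses", filters={"album": "pigs"})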

Also, is the search OR by default, i.e., finding ANY of the words in the query? Can we search with AND and other Boolean operators, as well as proximity and phrase search? Lucene has these features.

Any plans to combine KNN search with the text search?

@regstuff regstuff changed the title Ability to search and index documents with other metadata [Feature Request] Ability to search and index documents with other metadata Feb 2, 2023
@AmenRa AmenRa added the enhancement New feature or request label Feb 2, 2023
@AmenRa
Owner

AmenRa commented Feb 2, 2023

Hello, dear metalhead :)

I want to add most of the features you mentioned in the next few months.

Currently, I am working on adding ANN search, KNN search, and semantic re-ranking.
I plan to release those features in the next couple of weeks.

Search only works as you described (OR) for now.

I will notify you when new features are released.

@AmenRa
Owner

AmenRa commented Feb 19, 2023

Hi @regstuff, I added several new features in v0.2.0 (ANN search, KNN search, hybrid retrieval, ...).

@alex2awesome
Contributor

alex2awesome commented Aug 3, 2023

Hey @AmenRa, what's the status on filtering by metadata?

I would have used this library in at least 2-3 more papers over the past few months if it had:

  1. Filter by metadata
  2. Filter by date range

@AmenRa
Owner

AmenRa commented Aug 3, 2023

Hi, sorry for the delay.
Unfortunately, I do not have any update on this.
I'll try to find some time to work on it in the next few days.

@alex2awesome
Contributor

alex2awesome commented Aug 3, 2023

Honestly, even just being able to pass in a subset of indices and have them scored would be sufficient... I can do the date filtering and metadata filtering on my side.

As things stand, I would potentially need to score the whole index against a query and then filter afterwards, which is wasteful/unworkable...
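
To make that concrete, the workaround today looks roughly like this (just a sketch; `se` is a retriv retriever already built over the collection, and the allowed ids come from my own metadata/date filtering):

# score the whole index against the query, then throw most of it away
results = se.search("witches masses", cutoff=se.doc_count)
allowed = {"doc_2", "doc_3", "doc_4"}  # ids selected by my own metadata/date filters
filtered = [r for r in results if r["id"] in allowed][:100]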

@alex2awesome
Contributor

alex2awesome commented Aug 3, 2023

I have an implementation for this. I'm not 100% familiar with numba and numba's approach to typing so there may be a few errors there, but it works for the toy example:

from typing import List, Tuple, Union

import numba as nb
import numpy as np
import numpy.typing as npt
from numba import njit
from numba.typed import List as TypedList

from retriv import SparseRetriever
from retriv.paths import sr_state_path
# tf_idf's module path is assumed by analogy to retriv's bm25 module;
# bm25 itself is redefined below with filter support, so it is not imported from retriv.
from retriv.sparse_retriever.sparse_retrieval_models.tf_idf import tf_idf
from retriv.utils.numba_utils import join_sorted_multi_recursive, unsorted_top_k


# Copy of retriv's bm25 scorer, extended with an optional `filtered_doc_ids`
# argument so scoring can be restricted to an allowed subset of documents.
@njit(cache=True)
def bm25(
    b: float,
    k1: float,
    term_doc_freqs: nb.typed.List[np.ndarray],
    doc_ids: nb.typed.List[np.ndarray],
    filtered_doc_ids: Union[nb.typed.List[np.ndarray], None],
    relative_doc_lens: nb.typed.List[np.ndarray],
    doc_count: int,
    cutoff: int,
) -> Tuple[np.ndarray, np.ndarray]:
    if filtered_doc_ids is None:
        unique_doc_ids = join_sorted_multi_recursive(doc_ids)
    else:
        unique_doc_ids = join_sorted_multi_recursive(filtered_doc_ids)

    scores = np.empty(doc_count, dtype=np.float32)
    scores[unique_doc_ids] = 0.0  # Initialize scores

    for i in range(len(term_doc_freqs)):
        if filtered_doc_ids is None:
            indices = doc_ids[i]
        else:
            indices = filtered_doc_ids[i]
        freqs = term_doc_freqs[i]

        df = np.float32(len(doc_ids[i]))
        idf = np.float32(np.log(1.0 + (((doc_count - df) + 0.5) / (df + 0.5))))

        scores[indices] += idf * (
            (freqs * (k1 + 1.0))
            / (freqs + k1 * (1.0 - b + (b * relative_doc_lens[indices])))
        )

    scores = scores[unique_doc_ids]

    if cutoff < len(scores):
        scores, indices = unsorted_top_k(scores, cutoff)
        unique_doc_ids = unique_doc_ids[indices]

    indices = np.argsort(-scores)

    return unique_doc_ids[indices], scores[indices]



class MySparseRetriever(SparseRetriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Reverse mapping from original doc ids to internal ids (rebuilt in load() below).
        self.id_mapping_reverse = {v: k for k, v in getattr(self, "id_mapping", {}).items()}

    def get_doc_ids_and_term_freqs_filter(self,
                                          query_terms: List[str],
                                          allowed_list: Union[npt.ArrayLike, None]) -> List[nb.types.List]:
        if allowed_list is None:
            return [self.get_doc_ids(query_terms), None, self.get_term_doc_freqs(query_terms)]
        else:
            doc_ids = [self.inverted_index[t]["doc_ids"] for t in query_terms]
            allowed_ids = [np.in1d(doc_ids_i, allowed_list, assume_unique=True) for doc_ids_i in doc_ids]
            filtered_doc_ids = [full_list[b] for full_list, b in zip(doc_ids, allowed_ids)]
            tfs = [self.inverted_index[t]["tfs"] for t in query_terms]
            tfs = [full_list[b] for full_list, b in zip(tfs, allowed_ids)]
            return [TypedList(doc_ids), TypedList(filtered_doc_ids), TypedList(tfs)]

    def search(self, query: str, return_docs: bool = True, cutoff: int = 100,
               include_id_list: Union[List[str], None] = None,
               exclude_id_list: Union[List[str], None] = None,
               ) -> List:
        """Standard search functionality.

        Args:
            query (str): what to search for.
            return_docs (bool, optional): whether to return the texts of the documents. Defaults to True.
            cutoff (int, optional): number of results to return. Defaults to 100.
            include_id_list (list[str], optional): list of doc_ids to include. Defaults to None.
            exclude_id_list (list[str], optional): list of doc_ids to exclude. Defaults to None.

        Returns:
            List: results.
        """

        query_terms = self.query_preprocessing(query)
        if not query_terms:
            return {}
        query_terms = [t for t in query_terms if t in self.vocabulary]
        if not query_terms:
            return {}

        include_list = None
        if include_id_list is not None:
            include_list = [self.id_mapping_reverse[i] for i in include_id_list]
        if exclude_id_list is not None:
            exclude_id_list = set([self.id_mapping_reverse[i] for i in exclude_id_list])
            if include_list is None:
                include_list = self.id_mapping_reverse.values()
            include_list = [i for i in include_list if i not in exclude_id_list]

        doc_ids, filtered_doc_ids, term_doc_freqs = self.get_doc_ids_and_term_freqs_filter(query_terms, include_list)

        if self.model == "bm25":
            unique_doc_ids, scores = bm25(
                term_doc_freqs=term_doc_freqs,
                doc_ids=doc_ids,
                relative_doc_lens=self.relative_doc_lens,
                doc_count=self.doc_count,
                filtered_doc_ids=filtered_doc_ids,
                cutoff=cutoff,
                **self.hyperparams,
            )
        elif self.model == "tf-idf":
            # Note: the tf-idf path is retriv's original scorer; the id filter is not wired in here.
            unique_doc_ids, scores = tf_idf(
                term_doc_freqs=term_doc_freqs,
                doc_ids=doc_ids,
                doc_lens=self.doc_lens,
                cutoff=cutoff,
            )
        else:
            raise NotImplementedError()

        unique_doc_ids = self.map_internal_ids_to_original_ids(unique_doc_ids)

        if not return_docs:
            return dict(zip(unique_doc_ids, scores))

        return self.prepare_results(unique_doc_ids, scores)

    @staticmethod
    def load(index_name: str = "new-index"):
        """Load a retriever and its index.

        Args:
            index_name (str, optional): Name of the index. Defaults to "new-index".

        Returns:
            SparseRetriever: Sparse Retriever.
        """

        state = np.load(sr_state_path(index_name), allow_pickle=True)["state"][()]

        se = MySparseRetriever(**state["init_args"])
        se.initialize_doc_index()
        se.id_mapping = state["id_mapping"]
        se.doc_count = state["doc_count"]
        se.inverted_index = state["inverted_index"]
        se.vocabulary = set(se.inverted_index)
        se.doc_lens = state["doc_lens"]
        se.relative_doc_lens = state["relative_doc_lens"]
        se.hyperparams = state["hyperparams"]
        if 'id_mapping_reverse' not in state:
            se.id_mapping_reverse = {v:k for k, v in se.id_mapping.items()}
        else:
            se.id_mapping_reverse = state['id_mapping_reverse']

        # Rebuild the state dict in retriv's on-disk format, now including the reverse
        # mapping; saving it back to sr_state_path is left to the caller.
        state = {
            "init_args": se.init_args,
            "id_mapping": se.id_mapping,
            "doc_count": se.doc_count,
            "inverted_index": se.inverted_index,
            "vocabulary": se.vocabulary,
            "doc_lens": se.doc_lens,
            "relative_doc_lens": se.relative_doc_lens,
            "hyperparams": se.hyperparams,
            "id_mapping_reverse": se.id_mapping_reverse,
        }

        return se



collection = [
    {"id": "doc_1", "text": "Generals gathered in their masses"},
    {"id": "doc_2", "text": "Just like witches at black masses"},
    {"id": "doc_3", "text": "Evil minds that plot destruction"},
    {"id": "doc_4", "text": "Sorcerer of death's construction"},
]
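
# Assumed one-time indexing step (not shown in the original snippet): build the
# "new-index" index with retriv's standard API so that load() below has something to read.
base = SparseRetriever(index_name="new-index")
base.index(collection)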

se = MySparseRetriever.load("new-index")
print(se.search("witches masses"))
print(se.search("witches masses", include_id_list=["doc_2", "doc_3", "doc_4"]))```

prints:

[{'id': 'doc_2', 'text': 'Just like witches at black masses', 'score': 1.792371}, {'id': 'doc_1', 'text': 'Generals gathered in their masses', 'score': 0.7361701}]
[{'id': 'doc_2', 'text': 'Just like witches at black masses', 'score': 1.792371}]

@AmenRa
Copy link
Owner

AmenRa commented Aug 4, 2023

Hi, yes that's one way to do it.
I am starting to work on the requested features.
I will notify you when they are ready.

@alex2awesome
Contributor

Would you like me to open a PR?

@AmenRa
Owner

AmenRa commented Aug 7, 2023

I truly appreciate your help and feedback, but I am building a separate searcher with metadata and doc-id filtering, as I prefer to keep it separate from the SparseRetriever for the moment.
Therefore, there's no need to open a PR right now.
Thanks again!

@AmenRa
Owner

AmenRa commented Aug 23, 2023

Hi, I released an experimental retriever with filtering functionalities.
You can read more about it here.
It only supports single-query search at the moment. Feedback is welcome.

@alex2awesome
Contributor

Great!! Thanks so much man, I'll check it out. I actually just needed this.

Ahhh, I see it only supports SparseRetriever right now. I will take a look.

@alex2awesome
Contributor

alex2awesome commented Aug 23, 2023

I took a look; I haven't fully parsed what you're doing in SparseRetriever since I'm working with a DenseRetriever right now.

I can see why you didn't do the subset for the DenseRetriever, as faiss is super, super gnarly.

After a ton of headbanging and reading faiss source code, I have a recipe that worked in my case:

import os      # for expanding "~" in the metadata path below
import random

import faiss
from retriv.dense_retriever.dense_retriever import DenseRetriever

# load index
dr = DenseRetriever.load(index_name='full-index__2023-08-17') ## pretrained index
encoded_query = dr.encoder('test') ## test query
encoded_query = encoded_query.reshape(1, len(encoded_query))

## baseline_query
res_0 = dr.ann_searcher.faiss_index.search(encoded_query, 100) ## results when we don't pass in a subset of the ids

## subset 
ids_list = [
    527938, 447029, 144546, 657363, 278523, 168930, 210273, 643925,
    668591, 113593, 234311, 893935, 233993, 421327, 749351, 282167,
    531984, 234300,  31035, 571879,  29422,  29645, 103771, 791342,
    234419, 560528, 838081
]

## randomly sort them 
ids_list = sorted(ids_list, key=lambda x: random.random()) ## shuffle the ids in the `subset` that we wish to restrict the search to. let's see if faiss search maintains the order of the desired subset
## next three lines are the recipe for how to filter faiss_index to a subset of ids
sel = faiss.IDSelectorBatch(ids_list) 
search_params = faiss.SearchParametersHNSW(sel=sel, efSearch=526)
res_1 = dr.ann_searcher.faiss_index.search(encoded_query, 100, params=search_params)


## randomly sort them again
ids_list = sorted(ids_list, key=lambda x: random.random()) ## shuffle the ids in the `subset` that we wish to restrict the search to.
sel = faiss.IDSelectorBatch(ids_list) # this `id_list` has been shuffled, but we want to test if the faiss ordering is the same.
search_params = faiss.SearchParametersHNSW(sel=sel, efSearch=526)
res_2 = dr.ann_searcher.faiss_index.search(encoded_query, 100, params=search_params)

## sanity check tests
## 1️⃣1️⃣1️⃣1️⃣ test 1: does subsetted FAISS return all the ids in our subset
scores, doc_ids = res_1
doc_ids = list(filter(lambda x: x != -1, doc_ids[0]))  
print(sorted(doc_ids) == sorted(ids_list))
## 2️⃣2️⃣2️⃣2️⃣ test 2: are the two shuffled faiss lists equal to each other?
print((res_1[1] == res_2[1]).all())

## print metadata about our index (open() does not expand "~", so expanduser it).
print(open(os.path.expanduser('~/.retriv/collections/full-index__2023-08-17/faiss_index_infos.json')).read())

This returns:

## 1️⃣1️⃣1️⃣1️⃣ test 1: does the list of ids retrieved by faiss == the subset we wish to search over? Yes. ✅
>>> True
## 2️⃣2️⃣2️⃣2️⃣ test 2: does shuffling the order of the subset affect the faiss ordering? No. ✅
>>> True

## index metadata
>>> {"index_key": "HNSW32", "index_param": "efSearch=526", "index_path": "~/.retriv/collections/full-index__2023-08-17/faiss.index", "size in bytes": 3011679438, "avg_search_speed_ms": 9.243549004174742, "99p_search_speed_ms": 19.520642068237102, "reconstruction error %": 0.0, "nb vectors": 900585, "vectors dimension": 768, "compression ratio": 0.9186227076800861}

Note the efSearch=526 parameter and the SearchParametersHNSW class; these are specific to the indexing method that autofaiss uses. efSearch is super crucial for datasets with a lot of partitions.

This also changes for different kinds of indexes. For example, if we force autofaiss to make a different kind of index:

from autofaiss import build_index
import numpy as np

## force the index to be not HNSW
embeddings = np.float32(np.random.rand(10000, 512))
index, index_infos = build_index(embeddings, save_on_disk=False, should_be_memory_mappable=True)

## query
q = np.float32(np.random.rand(1, 512))  # faiss expects float32 queries
full_search_res = index.search(q, k=10)
subset_ids = full_search_res[1][0][3:5].tolist()

# this is how we restrict search to a subset of IDs when `autofaiss` uses the IVF index.
subset_res = index.search(
    q, 
    k=10,
    params=faiss.SearchParametersIVF(sel=faiss.IDSelectorBatch(subset_ids), nprobe=512)
)

# sanity check
### 1️⃣1️⃣1️⃣1️⃣ test 1: does FAISS search over all of the correct ids?
print(subset_ids == list(filter(lambda x: x != -1, subset_res[1][0])))

## here is info about our index
print(index_infos)

This returns:

### 1️⃣1️⃣1️⃣1️⃣ test 1:
True

### info about our index
{'index_key': 'IVF512,Flat', 'index_param': 'nprobe=512', 'index_path': '...knn.index', 'size in bytes': 21612811, 'avg_search_speed_ms': 1.0122442120318738, '99p_search_speed_ms': 1.1052904212556314, 'reconstruction error %': 0.0, 'nb vectors': 10000, 'vectors dimension': 512, 'compression ratio': 0.9475861330578425}

So I don't know enough about these indices and how autofaiss makes decisions to confidently handle every case, but it shouldn't be that hard to propagate errors up.
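
For what it's worth, the kind of dispatch I have in mind is something like this (the helper and the isinstance checks are just my guess at how to generalize the two recipes above; wrapped indexes like IndexPreTransform/IndexIDMap would still need unwrapping first):

import faiss

def subset_search(index, query_vecs, k, subset_ids):
    # Restrict a faiss search to `subset_ids`, picking the SearchParameters
    # class that matches the index type. Not exhaustive, just the two cases above.
    sel = faiss.IDSelectorBatch(subset_ids)
    if isinstance(index, faiss.IndexHNSW):
        params = faiss.SearchParametersHNSW(sel=sel)
    elif isinstance(index, faiss.IndexIVF):
        # searching all partitions mirrors the nprobe=512 recipe for the IVF512 index
        params = faiss.SearchParametersIVF(sel=sel, nprobe=index.nlist)
    else:
        raise NotImplementedError(f"no subset-search recipe for {type(index).__name__}")
    return index.search(query_vecs, k, params=params)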

Anyway, I'm happy to do a PR if this helps.

@AmenRa
Owner

AmenRa commented Aug 24, 2023

@alex2awesome can you add comments to your code, please? I do not understand what you are doing / trying to achieve.

@alex2awesome
Contributor

OK, I updated...

All I was doing was showing a recipe for how to use the faiss API to restrict a FAISS search to a subset of IDs, in two cases. The faiss documentation is very hard to parse, and not many people online have the same use case, so having this coded up in your library would be a huge value add.

The code I pasted here is showing how to do it, and running some sanity checks.
