diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb index c736bab1..d0185760 100644 --- a/docs/user_guide/hybrid_queries_02.ipynb +++ b/docs/user_guide/hybrid_queries_02.ipynb @@ -5,7 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Complex Queries\n", + "# Query\n", "\n", "In this notebook, we will explore more complex queries that can be performed with ``redisvl``\n", "\n", @@ -95,8 +95,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m19:55:11\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m19:55:11\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" + "\u001b[32m17:09:16\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m17:09:16\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n" ] } ], @@ -120,7 +120,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Executing Hybrid Queries\n", + "## Hybrid Queries\n", "\n", "Hybrid queries are queries that combine multiple types of filters. For example, you may want to search for a user that is a certain age, has a certain job, and is within a certain distance of a location. This is a hybrid query that combines numeric, tag, and geographic filters." ] @@ -544,6 +544,155 @@ "result_print(index.query(v))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Filter Queries\n", + "\n", + "In some cases, you may not want to run a vector query, but just use a ``FilterExpression`` similar to a SQL query. The ``FilterQuery`` class enables this functionality. It is similar to the ``VectorQuery`` class but takes only a ``FilterExpression``." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
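<table><tr><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>taimur</td><td>low</td><td>15</td><td>CEO</td></tr></table>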
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query import FilterQuery\n", + "\n", + "has_low_credit = Tag(\"credit_score\") == \"low\"\n", + "\n", + "filter_query = FilterQuery(\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"location\"],\n", + " filter_expression=has_low_credit\n", + ")\n", + "\n", + "results = index.query(filter_query)\n", + "\n", + "result_print(results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Range Queries\n", + "\n", + "Range Queries are a useful method to perform a vector search where only results within a vector ``distance_threshold`` are returned. This enables the user to find all records within their dataset that are similar to a query vector where \"similar\" is defined by a quantitative value." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
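<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr>
<tr><td>0.109129190445</td><td>tyler</td><td>high</td><td>100</td><td>engineer</td></tr>
<tr><td>0.158809006214</td><td>tim</td><td>high</td><td>12</td><td>dermatologist</td></tr></table>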
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from redisvl.query import RangeQuery\n", + "\n", + "range_query = RangeQuery(\n", + " vector=[0.1, 0.1, 0.5],\n", + " vector_field_name=\"user_embedding\",\n", + " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"location\"],\n", + " distance_threshold=0.2\n", + ")\n", + "\n", + "# same as the vector query or filter query\n", + "results = index.query(range_query)\n", + "\n", + "result_print(results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also change the distance threshold of the query object between uses if we like. Here we will set ``distance_threshold==0.1``. This means that the query object will return all matches that are within 0.1 of the query object. This is a small distance, so we expect to get fewer matches than before." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
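<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr>
<tr><td>0</td><td>derrick</td><td>low</td><td>14</td><td>doctor</td></tr></table>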
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "range_query.set_distance_threshold(0.1)\n", + "\n", + "result_print(index.query(range_query))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Range queries can also be used with filters like any other query type. The following limits the results to only include records with a ``job`` of ``engineer`` while also being within the vector range (aka distance)." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
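<table><tr><th>vector_distance</th><th>user</th><th>credit_score</th><th>age</th><th>job</th></tr>
<tr><td>0</td><td>john</td><td>high</td><td>18</td><td>engineer</td></tr></table>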
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "is_engineer = Text(\"job\") == \"engineer\"\n", + "\n", + "range_query.set_filter(is_engineer)\n", + "\n", + "result_print(index.query(range_query))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -559,7 +708,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -598,7 +747,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -607,7 +756,7 @@ "'@credit_score:{high}'" ] }, - "execution_count": 20, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -620,17 +769,17 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'id': 'v1:dc45946a8bc74f47858617c91d593b43', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n", - "{'id': 'v1:5c628fdfbba247c6843955de04e3a00c', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", - "{'id': 'v1:4f1cb6dd167149d59c9c108e09407fc9', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", - "{'id': 'v1:f1720dbeb81c4316bedf21ca60357fdf', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" + "{'id': 'v1:d78adb45342c4404a9c40afd4e65f51b', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': 
'==\\x00\\x00\\x00?'}\n", + "{'id': 'v1:a0a202b6398840c5ab2263b1fd4e704a', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n", + "{'id': 'v1:1f3b15dfb4ed490186859c1b2cb3df82', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n", + "{'id': 'v1:465de540d9d54501b09b8e47a0116620', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n" ] } ], @@ -653,7 +802,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -662,7 +811,7 @@ "'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10'" ] }, - "execution_count": 22, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index 2a6cce6f..19dd7f60 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -1,6 +1,3 @@ -from redisvl.query.query import FilterQuery, VectorQuery +from redisvl.query.query import FilterQuery, VectorQuery, RangeQuery -__all__ = [ - "VectorQuery", - "FilterQuery", -] +__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"] diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 03163f26..87b23a87 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import numpy as np from redis.commands.search.query import Query @@ -105,13 +105,13 @@ def params(self) 
-> Dict[str, Any]: return self._params -class VectorQuery(BaseQuery): +class BaseVectorQuery(BaseQuery): dtypes = { "float32": np.float32, "float64": np.float64, } - DISTANCE_ID = "vector_distance" + VECTOR_PARAM = "vector" def __init__( self, @@ -123,21 +123,6 @@ def __init__( num_results: Optional[int] = 10, return_score: bool = True, ): - """Query for vector fields - - Args: - vector (List[float]): The vector to query for. - vector_field_name (str): The name of the vector field - return_fields (List[str]): The fields to return. - filter_expression (FilterExpression, optional): A filter to apply to the query. Defaults to None. - dtype (str, optional): The dtype of the vector. Defaults to "float32". - num_results (Optional[int], optional): The number of results to return. Defaults to 10. - return_score (bool, optional): Whether to return the score. Defaults to True. - - Raises: - TypeError: If filter_expression is not of type redisvl.query.FilterExpression - - """ super().__init__(return_fields, num_results) self._vector = vector self._field = vector_field_name @@ -173,6 +158,45 @@ def get_filter(self) -> FilterExpression: def __str__(self): return " ".join([str(x) for x in self.query.get_args()]) + +class VectorQuery(BaseVectorQuery): + def __init__( + self, + vector: List[float], + vector_field_name: str, + return_fields: List[str], + filter_expression: FilterExpression = None, + dtype: str = "float32", + num_results: Optional[int] = 10, + return_score: bool = True, + ): + """Query for vector fields. + + Read more: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search + + Args: + vector (List[float]): The vector to query for. + vector_field_name (str): The name of the vector field. + return_fields (List[str]): The fields to return. + filter_expression (FilterExpression, optional): A filter to apply to the query. Defaults to None. + dtype (str, optional): The dtype of the vector. Defaults to "float32". 
+ num_results (Optional[int], optional): The number of results to return. Defaults to 10. + return_score (bool, optional): Whether to return the score. Defaults to True. + + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + + """ + super().__init__( + vector, + vector_field_name, + return_fields, + filter_expression, + dtype, + num_results, + return_score, + ) + @property def query(self) -> Query: """Return a Redis-Py Query object representing the query. @@ -180,7 +204,7 @@ def query(self) -> Query: Returns: redis.commands.search.query.Query: The query object. """ - base_query = f"{self._filter}=>[KNN {self._num_results} @{self._field} $vector AS {self.DISTANCE_ID}]" + base_query = f"{self._filter}=>[KNN {self._num_results} @{self._field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" query = ( Query(base_query) .return_fields(*self._return_fields) @@ -197,4 +221,119 @@ def params(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The parameters for the query. """ - return {"vector": array_to_buffer(self._vector, dtype=self.dtypes[self._dtype])} + return { + self.VECTOR_PARAM: array_to_buffer( + self._vector, dtype=self.dtypes[self._dtype] + ) + } + + +class RangeQuery(BaseVectorQuery): + DISTANCE_THRESHOLD_PARAM = "distance_threshold" + + def __init__( + self, + vector: List[float], + vector_field_name: str, + return_fields: List[str], + filter_expression: FilterExpression = None, + dtype: str = "float32", + distance_threshold: float = 0.2, + num_results: Optional[int] = None, + return_score: bool = True, + ): + """Vector query by distance range. + + Range queries are for filtering vector search results + by the distance between a vector field value and a query + vector, in terms of the index distance metric. + + Read more: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query + + Args: + vector (List[float]): The vector to query for. + vector_field_name (str): The name of the vector field. 
+ return_fields (List[str]): The fields to return. + filter_expression (FilterExpression, optional): A filter to apply to the query. Defaults to None. + dtype (str, optional): The dtype of the vector. Defaults to "float32". + distance_threshold (float, optional): The threshold for vector distance. Defaults to 0.2. + num_results (Optional[int], optional): The maximum number of results to return. Defaults to None. + return_score (bool, optional): Whether to return the score. Defaults to True. + + Raises: + TypeError: If filter_expression is not of type redisvl.query.FilterExpression + + """ + super().__init__( + vector, + vector_field_name, + return_fields, + filter_expression, + dtype, + num_results, + return_score, + ) + self.set_distance_threshold(distance_threshold) + + def set_distance_threshold(self, distance_threshold: float): + """Set the distance threshold for the query. + + Args: + distance_threshold (float): The maximum vector distance for a result to be returned. + """ + if not isinstance(distance_threshold, (float, int)): + raise TypeError("distance_threshold must be of type int or float") + self._distance_threshold = distance_threshold + + @property + def distance_threshold(self) -> float: + """Return the distance threshold for the query. + + Returns: + float: The distance threshold for the query. + """ + return self._distance_threshold + + @property + def query(self) -> Query: + """Return a Redis-Py Query object representing the query. + + Returns: + redis.commands.search.query.Query: The query object. 
+ """ + base_query = f"@{self._field}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" + + if self._filter != "*": + base_query = ( + "(" + + base_query + + f"=>{{$yield_distance_as: {self.DISTANCE_ID}}} " + + self._filter + + ")" + ) + else: + base_query += f"=>{{$yield_distance_as: {self.DISTANCE_ID}}}" + + query = ( + Query(base_query) + .return_fields(*self._return_fields) + .sort_by(self.DISTANCE_ID) + .dialect(2) + ) + if self._num_results: + query.paging(0, self._num_results) + return query + + @property + def params(self) -> Dict[str, Any]: + """Return the parameters for the query. + + Returns: + Dict[str, Any]: The parameters for the query. + """ + return { + self.VECTOR_PARAM: array_to_buffer( + self._vector, dtype=self.dtypes[self._dtype] + ), + self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold, + } diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 4ea04de5..f44ceb8d 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -5,7 +5,7 @@ from redis.commands.search.result import Result from redisvl.index import SearchIndex -from redisvl.query import FilterQuery, VectorQuery +from redisvl.query import FilterQuery, VectorQuery, RangeQuery from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text data = [ @@ -91,19 +91,6 @@ } -vector_query = VectorQuery( - [0.1, 0.1, 0.5], - "user_embedding", - return_fields=["user", "credit_score", "age", "job", "location"], -) - -filter_query = FilterQuery( - return_fields=["user", "credit_score", "age", "job", "location"], - # this will get set everytime - filter_expression=Tag("credit_score") == "high", -) - - @pytest.fixture(scope="module") def index(): # construct a search index from the schema @@ -142,7 +129,7 @@ def test_simple(index): assert doc.credit_score in ["high", "low", "medium"] -def test_search_qeury(index): +def test_search_query(index): # *=>[KNN 7 @user_embedding $vector AS vector_distance] v = VectorQuery( 
[0.1, 0.1, 0.5], @@ -162,23 +149,46 @@ def test_search_qeury(index): assert processed_results[0] == results.docs[0].__dict__ -def test_simple_tag_filter(index): - # (@credit_score:{high})=>[KNN 10 @user_embedding $vector AS vector_distance] - t = Tag("credit_score") == "high" - v = VectorQuery( +def test_range_query(index): + r = RangeQuery( [0.1, 0.1, 0.5], "user_embedding", return_fields=["user", "credit_score", "age", "job"], - filter_expression=t, + distance_threshold=0.2, + num_results=7, ) + results = index.query(r) + for result in results: + assert float(result["vector_distance"]) <= 0.2 + assert len(results) == 4 + assert r.distance_threshold == 0.2 - results = index.search(v.query, query_params=v.params) - assert len(results.docs) == 4 + r.set_distance_threshold(0.1) + assert r.distance_threshold == 0.1 + results = index.query(r) + for result in results: + assert float(result["vector_distance"]) <= 0.1 + assert len(results) == 2 -@pytest.fixture(params=[vector_query, filter_query], ids=["VectorQuery", "FilterQuery"]) -def query(request): - return request.param +vector_query = VectorQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job", "location"], +) + +filter_query = FilterQuery( + return_fields=["user", "credit_score", "age", "job", "location"], + # this will get set every time + filter_expression=Tag("credit_score") == "high", +) + +range_query = RangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job", "location"], + distance_threshold=0.2, +) def filter_test( @@ -189,17 +199,23 @@ def filter_test( credit_check=None, age_range=None, location=None, + distance_threshold=0.2, ): """Utility function to test filters""" # set the new filter query.set_filter(_filter) + print(str(query)) # print(str(v) + "\n") # to print the query results = index.search(query.query, query_params=query.params) + + # check for tag filter 
correctness if credit_check: for doc in results.docs: assert doc.credit_score == credit_check + + # check for numeric filter correctness if age_range: for doc in results.docs: if len(age_range) == 3: @@ -208,10 +224,29 @@ def filter_test( assert (int(doc.age) <= age_range[0]) or (int(doc.age) >= age_range[1]) else: assert age_range[0] <= int(doc.age) <= age_range[1] + + # check for geographic filter correctness if location: for doc in results.docs: assert doc.location == location - assert len(results.docs) == expected_count + + # if range query, test results by distance threshold + if isinstance(query, RangeQuery): + for doc in results.docs: + print(doc.vector_distance) + assert float(doc.vector_distance) <= distance_threshold + + # otherwise check by expected count. + else: + assert len(results.docs) == expected_count + + +@pytest.fixture( + params=[vector_query, filter_query, range_query], + ids=["VectorQuery", "FilterQuery", "RangeQuery"], +) +def query(request): + return request.param def test_filters(index, query):
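
The string assembly in `RangeQuery.query` can be checked without a Redis connection. The following is a minimal sketch under stated assumptions, not the library's implementation: `build_range_query_string` is a hypothetical helper that does not exist in `redisvl`, and it takes the filter as a plain string, whereas the class in this diff holds a `FilterExpression`. The three constants mirror the class attributes introduced above.

```python
# Constants mirroring the class attributes in this diff.
DISTANCE_ID = "vector_distance"
VECTOR_PARAM = "vector"
DISTANCE_THRESHOLD_PARAM = "distance_threshold"


def build_range_query_string(field: str, filter_str: str = "*") -> str:
    """Hypothetical helper: assemble the range query string the same way
    RangeQuery.query does, using a plain-string filter (an assumption)."""
    # Range clause: the threshold and the vector are passed as query params.
    base = f"@{field}:[VECTOR_RANGE ${DISTANCE_THRESHOLD_PARAM} ${VECTOR_PARAM}]"
    # Attribute that exposes the computed distance under DISTANCE_ID.
    yield_clause = f"=>{{$yield_distance_as: {DISTANCE_ID}}}"
    if filter_str != "*":
        # With a filter, the range clause and the filter are wrapped together.
        return f"({base}{yield_clause} {filter_str})"
    return base + yield_clause


if __name__ == "__main__":
    print(build_range_query_string("user_embedding"))
    print(build_range_query_string("user_embedding", "@credit_score:{high}"))
```

Keeping the `$distance_threshold` and `$vector` placeholders in the string, rather than interpolating values, matches the diff's approach of supplying both through `params`, which is why `set_distance_threshold` can change the threshold without rebuilding the query text.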