diff --git a/docs/user_guide/hybrid_queries_02.ipynb b/docs/user_guide/hybrid_queries_02.ipynb
index c736bab1..d0185760 100644
--- a/docs/user_guide/hybrid_queries_02.ipynb
+++ b/docs/user_guide/hybrid_queries_02.ipynb
@@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Complex Queries\n",
+ "# Query\n",
"\n",
"In this notebook, we will explore more complex queries that can be performed with ``redisvl``\n",
"\n",
@@ -95,8 +95,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "\u001b[32m19:55:11\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
- "\u001b[32m19:55:11\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n"
+ "\u001b[32m17:09:16\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n",
+ "\u001b[32m17:09:16\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_index\n"
]
}
],
@@ -120,7 +120,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Executing Hybrid Queries\n",
+ "## Hybrid Queries\n",
"\n",
"Hybrid queries are queries that combine multiple types of filters. For example, you may want to search for a user that is a certain age, has a certain job, and is within a certain distance of a location. This is a hybrid query that combines numeric, tag, and geographic filters."
]
@@ -544,6 +544,155 @@
"result_print(index.query(v))"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Filter Queries\n",
+ "\n",
+ "In some cases, you may not want to run a vector query, but just use a ``FilterExpression`` similar to a SQL query. The ``FilterQuery`` class enable this functionality. It is similar to the ``VectorQuery`` class but soley takes a ``FilterExpression``."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
user | credit_score | age | job |
---|
derrick | low | 14 | doctor |
taimur | low | 15 | CEO |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from redisvl.query import FilterQuery\n",
+ "\n",
+ "has_low_credit = Tag(\"credit_score\") == \"low\"\n",
+ "\n",
+ "filter_query = FilterQuery(\n",
+ " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"location\"],\n",
+ " filter_expression=has_low_credit\n",
+ ")\n",
+ "\n",
+ "results = index.query(filter_query)\n",
+ "\n",
+ "result_print(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Range Queries\n",
+ "\n",
+ "Range Queries are a useful method to perform a vector search where only results within a vector ``distance_threshold`` are returned. This enables the user to find all records within their dataset that are similar to a query vector where \"similar\" is defined by a quantitative value."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "vector_distance | user | credit_score | age | job |
---|
0 | john | high | 18 | engineer |
0 | derrick | low | 14 | doctor |
0.109129190445 | tyler | high | 100 | engineer |
0.158809006214 | tim | high | 12 | dermatologist |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from redisvl.query import RangeQuery\n",
+ "\n",
+ "range_query = RangeQuery(\n",
+ " vector=[0.1, 0.1, 0.5],\n",
+ " vector_field_name=\"user_embedding\",\n",
+ " return_fields=[\"user\", \"credit_score\", \"age\", \"job\", \"location\"],\n",
+ " distance_threshold=0.2\n",
+ ")\n",
+ "\n",
+ "# same as the vector query or filter query\n",
+ "results = index.query(range_query)\n",
+ "\n",
+ "result_print(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can also change the distance threshold of the query object between uses if we like. Here we will set ``distance_threshold==0.1``. This means that the query object will return all matches that are within 0.1 of the query object. This is a small distance, so we expect to get fewer matches than before."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "vector_distance | user | credit_score | age | job |
---|
0 | john | high | 18 | engineer |
0 | derrick | low | 14 | doctor |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "range_query.set_distance_threshold(0.1)\n",
+ "\n",
+ "result_print(index.query(range_query))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Range queries can also be used with filters like any other query type. The following limits the results to only include records with a ``job`` of ``engineer`` while also being within the vector range (aka distance)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "vector_distance | user | credit_score | age | job |
---|
0 | john | high | 18 | engineer |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "is_engineer = Text(\"job\") == \"engineer\"\n",
+ "\n",
+ "range_query.set_filter(is_engineer)\n",
+ "\n",
+ "result_print(index.query(range_query))"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -559,7 +708,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -598,7 +747,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 24,
"metadata": {},
"outputs": [
{
@@ -607,7 +756,7 @@
"'@credit_score:{high}'"
]
},
- "execution_count": 20,
+ "execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@@ -620,17 +769,17 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "{'id': 'v1:dc45946a8bc74f47858617c91d593b43', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n",
- "{'id': 'v1:5c628fdfbba247c6843955de04e3a00c', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n",
- "{'id': 'v1:4f1cb6dd167149d59c9c108e09407fc9', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n",
- "{'id': 'v1:f1720dbeb81c4316bedf21ca60357fdf', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n"
+ "{'id': 'v1:d78adb45342c4404a9c40afd4e65f51b', 'payload': None, 'user': 'john', 'age': '18', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '==\\x00\\x00\\x00?'}\n",
+ "{'id': 'v1:a0a202b6398840c5ab2263b1fd4e704a', 'payload': None, 'user': 'nancy', 'age': '94', 'job': 'doctor', 'credit_score': 'high', 'office_location': '-122.4194,37.7749', 'user_embedding': '333?=\\x00\\x00\\x00?'}\n",
+ "{'id': 'v1:1f3b15dfb4ed490186859c1b2cb3df82', 'payload': None, 'user': 'tyler', 'age': '100', 'job': 'engineer', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '=>\\x00\\x00\\x00?'}\n",
+ "{'id': 'v1:465de540d9d54501b09b8e47a0116620', 'payload': None, 'user': 'tim', 'age': '12', 'job': 'dermatologist', 'credit_score': 'high', 'office_location': '-122.0839,37.3861', 'user_embedding': '>>\\x00\\x00\\x00?'}\n"
]
}
],
@@ -653,7 +802,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
@@ -662,7 +811,7 @@
"'((@credit_score:{high} @age:[18 +inf]) @age:[-inf 100])=>[KNN 10 @user_embedding $vector AS vector_distance] RETURN 6 user credit_score age job office_location vector_distance SORTBY vector_distance ASC DIALECT 2 LIMIT 0 10'"
]
},
- "execution_count": 22,
+ "execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py
index 2a6cce6f..19dd7f60 100644
--- a/redisvl/query/__init__.py
+++ b/redisvl/query/__init__.py
@@ -1,6 +1,3 @@
-from redisvl.query.query import FilterQuery, VectorQuery
+from redisvl.query.query import FilterQuery, VectorQuery, RangeQuery
-__all__ = [
- "VectorQuery",
- "FilterQuery",
-]
+__all__ = ["VectorQuery", "FilterQuery", "RangeQuery"]
diff --git a/redisvl/query/query.py b/redisvl/query/query.py
index 03163f26..87b23a87 100644
--- a/redisvl/query/query.py
+++ b/redisvl/query/query.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np
from redis.commands.search.query import Query
@@ -105,13 +105,13 @@ def params(self) -> Dict[str, Any]:
return self._params
-class VectorQuery(BaseQuery):
+class BaseVectorQuery(BaseQuery):
dtypes = {
"float32": np.float32,
"float64": np.float64,
}
-
DISTANCE_ID = "vector_distance"
+ VECTOR_PARAM = "vector"
def __init__(
self,
@@ -123,21 +123,6 @@ def __init__(
num_results: Optional[int] = 10,
return_score: bool = True,
):
- """Query for vector fields
-
- Args:
- vector (List[float]): The vector to query for.
- vector_field_name (str): The name of the vector field
- return_fields (List[str]): The fields to return.
- filter_expression (FilterExpression, optional): A filter to apply to the query. Defaults to None.
- dtype (str, optional): The dtype of the vector. Defaults to "float32".
- num_results (Optional[int], optional): The number of results to return. Defaults to 10.
- return_score (bool, optional): Whether to return the score. Defaults to True.
-
- Raises:
- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
-
- """
super().__init__(return_fields, num_results)
self._vector = vector
self._field = vector_field_name
@@ -173,6 +158,45 @@ def get_filter(self) -> FilterExpression:
def __str__(self):
return " ".join([str(x) for x in self.query.get_args()])
+
+class VectorQuery(BaseVectorQuery):
+ def __init__(
+ self,
+ vector: List[float],
+ vector_field_name: str,
+ return_fields: List[str],
+ filter_expression: FilterExpression = None,
+ dtype: str = "float32",
+ num_results: Optional[int] = 10,
+ return_score: bool = True,
+ ):
+ """Query for vector fields.
+
+ Read more: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
+
+ Args:
+ vector (List[float]): The vector to query for.
+ vector_field_name (str): The name of the vector field.
+ return_fields (List[str]): The fields to return.
+ filter_expression (FilterExpression, optional): A filter to apply to the query. Defaults to None.
+ dtype (str, optional): The dtype of the vector. Defaults to "float32".
+ num_results (Optional[int], optional): The number of results to return. Defaults to 10.
+ return_score (bool, optional): Whether to return the score. Defaults to True.
+
+ Raises:
+ TypeError: If filter_expression is not of type redisvl.query.FilterExpression
+
+ """
+ super().__init__(
+ vector,
+ vector_field_name,
+ return_fields,
+ filter_expression,
+ dtype,
+ num_results,
+ return_score,
+ )
+
@property
def query(self) -> Query:
"""Return a Redis-Py Query object representing the query.
@@ -180,7 +204,7 @@ def query(self) -> Query:
Returns:
redis.commands.search.query.Query: The query object.
"""
- base_query = f"{self._filter}=>[KNN {self._num_results} @{self._field} $vector AS {self.DISTANCE_ID}]"
+ base_query = f"{self._filter}=>[KNN {self._num_results} @{self._field} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"
query = (
Query(base_query)
.return_fields(*self._return_fields)
@@ -197,4 +221,119 @@ def params(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The parameters for the query.
"""
- return {"vector": array_to_buffer(self._vector, dtype=self.dtypes[self._dtype])}
+ return {
+ self.VECTOR_PARAM: array_to_buffer(
+ self._vector, dtype=self.dtypes[self._dtype]
+ )
+ }
+
+
+class RangeQuery(BaseVectorQuery):
+ DISTANCE_THRESHOLD_PARAM = "distance_threshold"
+
+ def __init__(
+ self,
+ vector: List[float],
+ vector_field_name: str,
+ return_fields: List[str],
+ filter_expression: FilterExpression = None,
+ dtype: str = "float32",
+ distance_threshold: float = 0.2,
+ num_results: Optional[int] = None,
+ return_score: bool = True,
+ ):
+ """Vector query by distance range.
+
+ Range queries are for filtering vector search results
+ by the distance between a vector field value and a query
+ vector, in terms of the index distance metric.
+
+ Read more: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query
+
+ Args:
+ vector (List[float]): The vector to query for.
+ vector_field_name (str): The name of the vector field.
+ return_fields (List[str]): The fields to return.
+ filter_expression (FilterExpression, optional): A filter to apply to the query. Defaults to None.
+ dtype (str, optional): The dtype of the vector. Defaults to "float32".
+ distance_threshold (str, float): The threshold for vector distance. Defaults to 0.2.
+ num_results (Optional[int], optional): The MAX number of results to return. Defaults to None.
+ return_score (bool, optional): Whether to return the score. Defaults to True.
+
+ Raises:
+ TypeError: If filter_expression is not of type redisvl.query.FilterExpression
+
+ """
+ super().__init__(
+ vector,
+ vector_field_name,
+ return_fields,
+ filter_expression,
+ dtype,
+ num_results,
+ return_score,
+ )
+ self.set_distance_threshold(distance_threshold)
+
+ def set_distance_threshold(self, distance_threshold: float):
+ """_summary_
+
+ Args:
+ distance_threshold (float): _description_
+ """
+ if not isinstance(distance_threshold, (float, int)):
+ raise TypeError("distance_threshold must be of type int or float")
+ self._distance_threshold = distance_threshold
+
+ @property
+ def distance_threshold(self) -> float:
+ """Return the distance threshold for the query.
+
+ Returns:
+ float: The distance threshold for the query.
+ """
+ return self._distance_threshold
+
+ @property
+ def query(self) -> Query:
+ """Return a Redis-Py Query object representing the query.
+
+ Returns:
+ redis.commands.search.query.Query: The query object.
+ """
+ base_query = f"@{self._field}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]"
+
+ if self._filter != "*":
+ base_query = (
+ "("
+ + base_query
+ + f"=>{{$yield_distance_as: {self.DISTANCE_ID}}} "
+ + self._filter
+ + ")"
+ )
+ else:
+ base_query += f"=>{{$yield_distance_as: {self.DISTANCE_ID}}}"
+
+ query = (
+ Query(base_query)
+ .return_fields(*self._return_fields)
+ .sort_by(self.DISTANCE_ID)
+ .dialect(2)
+ )
+ if self._num_results:
+ query.paging(0, self._num_results)
+ return query
+
+ @property
+ def params(self) -> Dict[str, Any]:
+ """Return the parameters for the query.
+
+ Returns:
+ Dict[str, Any]: The parameters for the query.
+ """
+ return {
+ self.VECTOR_PARAM: array_to_buffer(
+ self._vector, dtype=self.dtypes[self._dtype]
+ ),
+ self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold,
+ }
diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py
index 4ea04de5..f44ceb8d 100644
--- a/tests/integration/test_query.py
+++ b/tests/integration/test_query.py
@@ -5,7 +5,7 @@
from redis.commands.search.result import Result
from redisvl.index import SearchIndex
-from redisvl.query import FilterQuery, VectorQuery
+from redisvl.query import FilterQuery, VectorQuery, RangeQuery
from redisvl.query.filter import Geo, GeoRadius, Num, Tag, Text
data = [
@@ -91,19 +91,6 @@
}
-vector_query = VectorQuery(
- [0.1, 0.1, 0.5],
- "user_embedding",
- return_fields=["user", "credit_score", "age", "job", "location"],
-)
-
-filter_query = FilterQuery(
- return_fields=["user", "credit_score", "age", "job", "location"],
- # this will get set everytime
- filter_expression=Tag("credit_score") == "high",
-)
-
-
@pytest.fixture(scope="module")
def index():
# construct a search index from the schema
@@ -142,7 +129,7 @@ def test_simple(index):
assert doc.credit_score in ["high", "low", "medium"]
-def test_search_qeury(index):
+def test_search_query(index):
# *=>[KNN 7 @user_embedding $vector AS vector_distance]
v = VectorQuery(
[0.1, 0.1, 0.5],
@@ -162,23 +149,46 @@ def test_search_qeury(index):
assert processed_results[0] == results.docs[0].__dict__
-def test_simple_tag_filter(index):
- # (@credit_score:{high})=>[KNN 10 @user_embedding $vector AS vector_distance]
- t = Tag("credit_score") == "high"
- v = VectorQuery(
+def test_range_query(index):
+ r = RangeQuery(
[0.1, 0.1, 0.5],
"user_embedding",
return_fields=["user", "credit_score", "age", "job"],
- filter_expression=t,
+ distance_threshold=0.2,
+ num_results=7,
)
+ results = index.query(r)
+ for result in results:
+ assert float(result["vector_distance"]) <= 0.2
+ assert len(results) == 4
+ assert r.distance_threshold == 0.2
- results = index.search(v.query, query_params=v.params)
- assert len(results.docs) == 4
+ r.set_distance_threshold(0.1)
+ assert r.distance_threshold == 0.1
+ results = index.query(r)
+ for result in results:
+ assert float(result["vector_distance"]) <= 0.1
+ assert len(results) == 2
-@pytest.fixture(params=[vector_query, filter_query], ids=["VectorQuery", "FilterQuery"])
-def query(request):
- return request.param
+vector_query = VectorQuery(
+ vector=[0.1, 0.1, 0.5],
+ vector_field_name="user_embedding",
+ return_fields=["user", "credit_score", "age", "job", "location"],
+)
+
+filter_query = FilterQuery(
+ return_fields=["user", "credit_score", "age", "job", "location"],
+ # this will get set everytime
+ filter_expression=Tag("credit_score") == "high",
+)
+
+range_query = RangeQuery(
+ vector=[0.1, 0.1, 0.5],
+ vector_field_name="user_embedding",
+ return_fields=["user", "credit_score", "age", "job", "location"],
+ distance_threshold=0.2,
+)
def filter_test(
@@ -189,17 +199,23 @@ def filter_test(
credit_check=None,
age_range=None,
location=None,
+ distance_threshold=0.2,
):
"""Utility function to test filters"""
# set the new filter
query.set_filter(_filter)
+ print(str(query))
# print(str(v) + "\n") # to print the query
results = index.search(query.query, query_params=query.params)
+
+ # check for tag filter correctness
if credit_check:
for doc in results.docs:
assert doc.credit_score == credit_check
+
+ # check for numeric filter correctness
if age_range:
for doc in results.docs:
if len(age_range) == 3:
@@ -208,10 +224,29 @@ def filter_test(
assert (int(doc.age) <= age_range[0]) or (int(doc.age) >= age_range[1])
else:
assert age_range[0] <= int(doc.age) <= age_range[1]
+
+ # check for geographic filter correctness
if location:
for doc in results.docs:
assert doc.location == location
- assert len(results.docs) == expected_count
+
+ # if range query, test results by distance threshold
+ if isinstance(query, RangeQuery):
+ for doc in results.docs:
+ print(doc.vector_distance)
+ assert float(doc.vector_distance) <= distance_threshold
+
+ # otherwise check by expected count.
+ else:
+ assert len(results.docs) == expected_count
+
+
+@pytest.fixture(
+ params=[vector_query, filter_query, range_query],
+ ids=["VectorQuery", "FilterQuery", "RangeQuery"],
+)
+def query(request):
+ return request.param
def test_filters(index, query):