Skip to content

Commit

Permalink
Adds support for new vector data types (#222)
Browse files Browse the repository at this point in the history
Adds support for bfloat16, float16, and float64 data types.

Users can now specify the desired dtype for their schema when creating
semantic caches, semantic routers, and semantic session managers.

When not specified, the data type still defaults to float32.
  • Loading branch information
justin-cechmanek authored Oct 2, 2024
1 parent f808b80 commit 42e33a3
Show file tree
Hide file tree
Showing 33 changed files with 328 additions and 98 deletions.
2 changes: 1 addition & 1 deletion docs/api/schema.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Each field type supports specific attributes that customize its behavior. Below

- `dims`: Dimensionality of the vector.
- `algorithm`: Indexing algorithm (`flat` or `hnsw`).
- `datatype`: Float datatype of the vector (`float32` or `float64`).
- `datatype`: Float datatype of the vector (`bfloat16`, `float16`, `float32`, `float64`).
- `distance_metric`: Metric for measuring query relevance (`COSINE`, `L2`, `IP`).

**HNSW Vector Field Specific Attributes**:
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/openai_qna.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@
"api_key = os.getenv(\"OPENAI_API_KEY\") or getpass.getpass(\"Enter your OpenAI API key: \")\n",
"oaip = OpenAITextVectorizer(EMBEDDINGS_MODEL, api_config={\"api_key\": api_key})\n",
"\n",
"chunked_data[\"embedding\"] = oaip.embed_many(chunked_data[\"content\"].tolist(), as_buffer=True)\n",
"chunked_data[\"embedding\"] = oaip.embed_many(chunked_data[\"content\"].tolist(), as_buffer=True, dtype=\"float32\")\n",
"chunked_data"
]
},
Expand Down Expand Up @@ -1073,7 +1073,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.12.2"
},
"orig_nbformat": 4
},
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/hash_vs_json_05.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@
"json_data = data.copy()\n",
"\n",
"for d in json_data:\n",
" d['user_embedding'] = buffer_to_array(d['user_embedding'], dtype=np.float32)"
" d['user_embedding'] = buffer_to_array(d['user_embedding'], dtype='float32')"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/user_guide/vectorizers_04.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@
"outputs": [],
"source": [
"# You can also create many embeddings at once\n",
"embeddings = hf.embed_many(sentences, as_buffer=True)\n"
"embeddings = hf.embed_many(sentences, as_buffer=True, dtype=\"float32\")\n"
]
},
{
Expand Down Expand Up @@ -569,7 +569,7 @@
"source": [
"from redisvl.utils.vectorize import CustomTextVectorizer\n",
"\n",
"def generate_embeddings(text_input):\n",
"def generate_embeddings(text_input, **kwargs):\n",
" return [0.101] * 768\n",
"\n",
"custom_vectorizer = CustomTextVectorizer(generate_embeddings)\n",
Expand Down
45 changes: 41 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ redis = ">=5.0.0"
pydantic = { version = ">=2,<3" }
tenacity = ">=8.2.2"
tabulate = { version = ">=0.9.0,<1" }
ml-dtypes = "^0.4.0"
openai = { version = ">=1.13.0", optional = true }
sentence-transformers = { version = ">=2.2.2", optional = true }
google-cloud-aiplatform = { version = ">=1.26", optional = true }
Expand Down
8 changes: 4 additions & 4 deletions redisvl/extensions/llmcache/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def non_empty_metadata(cls, v):
raise TypeError("Metadata must be a dictionary.")
return v

def to_dict(self) -> Dict:
def to_dict(self, dtype: str) -> Dict:
data = self.dict(exclude_none=True)
data["prompt_vector"] = array_to_buffer(self.prompt_vector)
data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype)
if self.metadata is not None:
data["metadata"] = serialize(self.metadata)
if self.filters is not None:
Expand Down Expand Up @@ -112,7 +112,7 @@ def to_dict(self) -> Dict:
class SemanticCacheIndexSchema(IndexSchema):

@classmethod
def from_params(cls, name: str, prefix: str, vector_dims: int):
def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str):

return cls(
index={"name": name, "prefix": prefix}, # type: ignore
Expand All @@ -126,7 +126,7 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
"type": "vector",
"attrs": {
"dims": vector_dims,
"datatype": "float32",
"datatype": dtype,
"distance_metric": "cosine",
"algorithm": "flat",
},
Expand Down
13 changes: 9 additions & 4 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def __init__(
]

# Create semantic cache schema and index
schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims)
dtype = kwargs.get("dtype", "float32")
schema = SemanticCacheIndexSchema.from_params(
name, prefix, vectorizer.dims, dtype
)
schema = self._modify_schema(schema, filterable_fields)
self._index = SearchIndex(schema=schema)

Expand Down Expand Up @@ -137,6 +140,7 @@ def __init__(
self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore
)
self._vectorizer = vectorizer
self._dtype = self.index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr]

def _modify_schema(
self,
Expand Down Expand Up @@ -286,7 +290,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
if not isinstance(prompt, str):
raise TypeError("Prompt must be a string.")

return self._vectorizer.embed(prompt)
return self._vectorizer.embed(prompt, dtype=self._dtype)

async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
"""Converts a text prompt to its vector representation using the
Expand Down Expand Up @@ -368,6 +372,7 @@ def check(
num_results=num_results,
return_score=True,
filter_expression=filter_expression,
dtype=self._dtype,
)

# Search the cache!
Expand Down Expand Up @@ -538,7 +543,7 @@ def store(
# Load cache entry with TTL
ttl = ttl or self._ttl
keys = self._index.load(
data=[cache_entry.to_dict()],
data=[cache_entry.to_dict(self._dtype)],
ttl=ttl,
id_field=ENTRY_ID_FIELD_NAME,
)
Expand Down Expand Up @@ -602,7 +607,7 @@ async def astore(
# Load cache entry with TTL
ttl = ttl or self._ttl
keys = await aindex.load(
data=[cache_entry.to_dict()],
data=[cache_entry.to_dict(self._dtype)],
ttl=ttl,
id_field=ENTRY_ID_FIELD_NAME,
)
Expand Down
8 changes: 4 additions & 4 deletions redisvl/extensions/router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic.v1 import BaseModel, Field, validator

from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.schema import IndexInfo, IndexSchema
from redisvl.schema import IndexSchema


class Route(BaseModel):
Expand Down Expand Up @@ -89,7 +89,7 @@ class SemanticRouterIndexSchema(IndexSchema):
"""Customized index schema for SemanticRouter."""

@classmethod
def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema":
def from_params(cls, name: str, vector_dims: int, dtype: str):
"""Create an index schema based on router name and vector dimensions.
Args:
Expand All @@ -100,7 +100,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema"
SemanticRouterIndexSchema: The constructed index schema.
"""
return cls(
index=IndexInfo(name=name, prefix=name),
index={"name": name, "prefix": name}, # type: ignore
fields=[ # type: ignore
{"name": "route_name", "type": "tag"},
{"name": "reference", "type": "text"},
Expand All @@ -111,7 +111,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema"
"algorithm": "flat",
"dims": vector_dims,
"distance_metric": "cosine",
"datatype": "float32",
"datatype": dtype,
},
},
],
Expand Down
28 changes: 24 additions & 4 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,26 +85,42 @@ def __init__(
vectorizer=vectorizer,
routing_config=routing_config,
)
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)
dtype = kwargs.get("dtype", "float32")
self._initialize_index(
redis_client, redis_url, overwrite, dtype, **connection_kwargs
)

def _initialize_index(
self,
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
overwrite: bool = False,
dtype: str = "float32",
**connection_kwargs,
):
"""Initialize the search index and handle Redis connection."""
schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims)
schema = SemanticRouterIndexSchema.from_params(
self.name, self.vectorizer.dims, dtype
)
self._index = SearchIndex(schema=schema)

if redis_client:
self._index.set_client(redis_client)
elif redis_url:
self._index.connect(redis_url=redis_url, **connection_kwargs)

# Check for existing router index
existed = self._index.exists()
self._index.create(overwrite=overwrite)
if not overwrite and existed:
existing_index = SearchIndex.from_existing(
self.name, redis_client=self._index.client
)
if existing_index.schema != self._index.schema:
raise ValueError(
f"Existing index {self.name} schema does not match the user provided schema for the semantic router. "
"If you wish to overwrite the index schema, set overwrite=True during initialization."
)
self._index.create(overwrite=overwrite, drop=False)

if not existed or overwrite:
# write the routes to Redis
Expand Down Expand Up @@ -153,7 +169,9 @@ def _add_routes(self, routes: List[Route]):
for route in routes:
# embed route references as a single batch
reference_vectors = self.vectorizer.embed_many(
[reference for reference in route.references], as_buffer=True
[reference for reference in route.references],
as_buffer=True,
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)
# set route references
for i, reference in enumerate(route.references):
Expand Down Expand Up @@ -230,6 +248,7 @@ def _classify_route(
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
distance_threshold=distance_threshold,
return_fields=["route_name"],
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)

aggregate_request = self._build_aggregate_request(
Expand Down Expand Up @@ -282,6 +301,7 @@ def _classify_multi_route(
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
distance_threshold=distance_threshold,
return_fields=["route_name"],
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)
aggregate_request = self._build_aggregate_request(
vector_range_query, aggregation_method, max_k
Expand Down
9 changes: 4 additions & 5 deletions redisvl/extensions/session_manager/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,14 @@ def generate_id(cls, values):
)
return values

def to_dict(self) -> Dict:
def to_dict(self, dtype: Optional[str] = None) -> Dict:
data = self.dict(exclude_none=True)

# handle optional fields
if SESSION_VECTOR_FIELD_NAME in data:
data[SESSION_VECTOR_FIELD_NAME] = array_to_buffer(
data[SESSION_VECTOR_FIELD_NAME]
data[SESSION_VECTOR_FIELD_NAME], dtype # type: ignore[arg-type]
)

return data


Expand All @@ -80,7 +79,7 @@ def from_params(cls, name: str, prefix: str):
class SemanticSessionIndexSchema(IndexSchema):

@classmethod
def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str):

return cls(
index={"name": name, "prefix": prefix}, # type: ignore
Expand All @@ -95,7 +94,7 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
"type": "vector",
"attrs": {
"dims": vectorizer_dims,
"datatype": "float32",
"datatype": dtype,
"distance_metric": "cosine",
"algorithm": "flat",
},
Expand Down
Loading

0 comments on commit 42e33a3

Please sign in to comment.