Skip to content

Commit

Permalink
use appropriate pydantic shim
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Apr 27, 2024
1 parent d671019 commit 6388c4b
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 33 deletions.
2 changes: 1 addition & 1 deletion redisvl/utils/rerank/cohere.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import PrivateAttr
from pydantic.v1 import PrivateAttr

from redisvl.utils.rerank.base import BaseReranker

Expand Down
13 changes: 3 additions & 10 deletions redisvl/utils/vectorize/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional
from typing import Callable, List, Optional

from pydantic.v1 import BaseModel, validator
from pydantic.v1 import BaseModel

from redisvl.redis.utils import array_to_buffer


class BaseVectorizer(BaseModel, ABC):
model: str
dims: int

@validator("dims", pre=True)
@classmethod
def check_dims(cls, v):
if v <= 0:
raise ValueError("Dimension must be a positive integer")
return v
dims: Optional[int]

@abstractmethod
def embed_many(
Expand Down
32 changes: 19 additions & 13 deletions redisvl/utils/vectorize/text/azureopenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

from pydantic.v1 import PrivateAttr
from redisvl.utils.vectorize.base import BaseVectorizer

# ignore that openai isn't imported
Expand Down Expand Up @@ -47,7 +47,8 @@ class AzureOpenAITextVectorizer(BaseVectorizer):
"""

aclient: Any # Since the OpenAI module is loaded dynamically
_client: Any = PrivateAttr()
_aclient: Any = PrivateAttr()

def __init__(
self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None
Expand All @@ -65,6 +66,14 @@ def __init__(
ImportError: If the openai library is not installed.
ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
"""
self._initialize_clients(api_config)
super().__init__(model=model, dims=self._set_model_dims(model))

def _initialize_clients(self, api_config: Optional[Dict]):
"""
Setup the OpenAI clients using the provided API key or an
environment variable.
"""
# Dynamic import of the openai module
try:
from openai import AsyncAzureOpenAI, AzureOpenAI
Expand Down Expand Up @@ -114,20 +123,17 @@ def __init__(
environment variable."
)

client = AzureOpenAI(
self._client = AzureOpenAI(
api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
)
dims = self._set_model_dims(client, model)
super().__init__(model=model, dims=dims, client=client)
self.aclient = AsyncAzureOpenAI(
self._aclient = AsyncAzureOpenAI(
api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
)

@staticmethod
def _set_model_dims(client, model) -> int:
def _set_model_dims(self, model) -> int:
try:
embedding = (
client.embeddings.create(input=["dimension test"], model=model)
self._client.embeddings.create(input=["dimension test"], model=model)
.data[0]
.embedding
)
Expand Down Expand Up @@ -175,7 +181,7 @@ def embed_many(

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = self.client.embeddings.create(input=batch, model=self.model)
response = self._client.embeddings.create(input=batch, model=self.model)
embeddings += [
self._process_embedding(r.embedding, as_buffer) for r in response.data
]
Expand Down Expand Up @@ -213,7 +219,7 @@ def embed(

if preprocess:
text = preprocess(text)
result = self.client.embeddings.create(input=[text], model=self.model)
result = self._client.embeddings.create(input=[text], model=self.model)
return self._process_embedding(result.data[0].embedding, as_buffer)

@retry(
Expand Down Expand Up @@ -253,7 +259,7 @@ async def aembed_many(

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = await self.aclient.embeddings.create(
response = await self._aclient.embeddings.create(
input=batch, model=self.model
)
embeddings += [
Expand Down Expand Up @@ -293,5 +299,5 @@ async def aembed(

if preprocess:
text = preprocess(text)
result = await self.aclient.embeddings.create(input=[text], model=self.model)
result = await self._aclient.embeddings.create(input=[text], model=self.model)
return self._process_embedding(result.data[0].embedding, as_buffer)
26 changes: 22 additions & 4 deletions redisvl/utils/vectorize/text/cohere.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Any, Callable, Dict, List, Optional

from pydantic import PrivateAttr
from pydantic.v1 import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

Expand Down Expand Up @@ -90,7 +90,6 @@ def _initialize_client(self, api_config: Optional[Dict]):
"Provide it in api_config or set the COHERE_API_KEY environment variable."
)
self._client = Client(api_key=api_key, client_name="redisvl")
self._aclient = AsyncClient(api_key=api_key, client_name="redisvl")

def _set_model_dims(self, model) -> int:
try:
Expand Down Expand Up @@ -158,7 +157,7 @@ def embed(
)
if preprocess:
text = preprocess(text)
embedding = self.client.embed(
embedding = self._client.embed(
texts=[text], model=self.model, input_type=input_type
).embeddings[0]
return self._process_embedding(embedding, as_buffer)
Expand Down Expand Up @@ -227,11 +226,30 @@ def embed_many(

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = self.client.embed(
response = self._client.embed(
texts=batch, model=self.model, input_type=input_type
)
embeddings += [
self._process_embedding(embedding, as_buffer)
for embedding in response.embeddings
]
return embeddings

async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
batch_size: int = 1000,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
raise NotImplementedError

async def aembed(
self,
text: str,
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
raise NotImplementedError
25 changes: 22 additions & 3 deletions redisvl/utils/vectorize/text/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, List, Optional

from pydantic import PrivateAttr
from pydantic.v1 import PrivateAttr

from redisvl.utils.vectorize.base import BaseVectorizer

Expand Down Expand Up @@ -99,7 +99,7 @@ def embed(

if preprocess:
text = preprocess(text)
embedding = self.client.encode([text])[0]
embedding = self._client.encode([text])[0]
return self._process_embedding(embedding.tolist(), as_buffer)

def embed_many(
Expand Down Expand Up @@ -135,11 +135,30 @@ def embed_many(

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
batch_embeddings = self.client.encode(batch)
batch_embeddings = self._client.encode(batch)
embeddings.extend(
[
self._process_embedding(embedding.tolist(), as_buffer)
for embedding in batch_embeddings
]
)
return embeddings

async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
batch_size: int = 1000,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
raise NotImplementedError

async def aembed(
self,
text: str,
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
raise NotImplementedError
2 changes: 1 addition & 1 deletion redisvl/utils/vectorize/text/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Any, Callable, Dict, List, Optional

from pydantic import PrivateAttr
from pydantic.v1 import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

Expand Down
21 changes: 20 additions & 1 deletion redisvl/utils/vectorize/text/vertexai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Any, Callable, Dict, List, Optional

from pydantic import PrivateAttr
from pydantic.v1 import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

Expand Down Expand Up @@ -193,3 +193,22 @@ def embed(
text = preprocess(text)
result = self._client.get_embeddings([text])
return self._process_embedding(result[0].values, as_buffer)

async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
batch_size: int = 1000,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
raise NotImplementedError

async def aembed(
self,
text: str,
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
raise NotImplementedError

0 comments on commit 6388c4b

Please sign in to comment.