Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added AmazonBedrockVectorizer class #1151

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions dsp/modules/sentence_vectorizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from typing import List, Optional

import boto3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update with try/except block like in aws_providers.

import numpy as np
import openai

Expand Down Expand Up @@ -247,4 +247,93 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
texts_to_vectorize = self._extract_text_from_examples(inp_examples)
embeddings = self._model.embed(texts_to_vectorize, batch_size=self._batch_size, parallel=self._parallel)

return np.array([embedding.tolist() for embedding in embeddings], dtype=np.float32)
return np.array([embedding.tolist() for embedding in embeddings], dtype=np.float32)

class AmazonBedrockVectorizer(BaseSentenceVectorizer):
    '''
    This vectorizer uses Amazon Bedrock API to convert texts to embeddings.

    Amazon Titan text-embedding models take a single ``inputText`` per request
    and return a single vector, so for those models every text is sent in its
    own request and ``embed_batch_size`` is not used. Cohere embedding models
    take a list of texts and are called in batches of at most
    ``embed_batch_size`` texts (capped at the per-call limit documented for
    Bedrock's Cohere Embed endpoint).
    '''
    SUPPORTED_MODELS = [
        "amazon.titan-embed-text-v1", "amazon.titan-embed-text-v2:0",
        "cohere.embed-english-v3", "cohere.embed-multilingual-v3",
    ]

    # Value sent in the "dimensions" field of Titan v2 requests.
    _TITAN_V2_DIMENSIONS = 512
    # Bedrock documents a maximum of 96 texts per Cohere Embed call.
    _COHERE_MAX_TEXTS_PER_CALL = 96

    def __init__(
        self,
        model_id: str = 'amazon.titan-embed-text-v2:0',
        embed_batch_size: int = 128,
        region_name: str = 'us-west-2',
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
    ):
        '''
        Store the model configuration and create the Bedrock runtime client.

        Args:
            model_id: Bedrock model identifier; checked against
                ``SUPPORTED_MODELS`` when the vectorizer is called.
            embed_batch_size: Number of texts per request for Cohere models.
            region_name: AWS region passed to ``boto3.client``.
            aws_access_key_id: Optional access key passed to ``boto3.client``.
            aws_secret_access_key: Optional secret key passed to ``boto3.client``.
        '''
        self.model_id = model_id
        self.embed_batch_size = embed_batch_size

        # TODO(review): a reviewer asked whether this client could be replaced
        # with the existing DSPy AWS model integrations; what that would look
        # like is still an open question on the pull request.
        self.bedrock_client = boto3.client(
            service_name='bedrock-runtime',
            region_name=region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

    def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
        '''
        Embed the given examples and return one float32 row per example.

        Args:
            inp_examples: Examples (or plain strings) to embed.

        Returns:
            A ``np.float32`` array built from the list of embedding vectors,
            one vector per input text. Empty input yields an empty array.

        Raises:
            ValueError: If ``model_id`` is not in ``SUPPORTED_MODELS`` or
                ``embed_batch_size`` is not positive for a Cohere model.
            NotImplementedError: If a model passes the support check but no
                request/response format is implemented for it.
        '''
        # Function-scope import: `json` is needed here but is not among the
        # module-level imports visible next to this class.
        import json

        # Validate once up front instead of once per batch.
        if self.model_id not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model: {self.model_id}")

        text_to_vectorize = self._extract_text_from_examples(inp_examples)
        embeddings_list = []

        for request_body in self._build_request_bodies(text_to_vectorize):
            response = self.bedrock_client.invoke_model(
                body=json.dumps(request_body),
                modelId=self.model_id,
                accept='application/json',
                contentType='application/json',
            )
            response_body = json.loads(response['body'].read())
            embeddings_list.extend(self._parse_embeddings(response_body))

        return np.array(embeddings_list, dtype=np.float32)

    def _build_request_bodies(self, texts: List[str]) -> List[dict]:
        '''Return one request payload (not yet JSON-encoded) per Bedrock call.'''
        if self.model_id == "amazon.titan-embed-text-v1":
            # Titan accepts exactly one text per request.
            return [{"inputText": text} for text in texts]

        if self.model_id == "amazon.titan-embed-text-v2:0":
            return [
                {"inputText": text, "dimensions": self._TITAN_V2_DIMENSIONS}
                for text in texts
            ]

        if self.model_id.startswith("cohere.embed"):
            if self.embed_batch_size < 1:
                raise ValueError(
                    f"embed_batch_size must be a positive integer, got {self.embed_batch_size}."
                )
            step = min(self.embed_batch_size, self._COHERE_MAX_TEXTS_PER_CALL)
            return [
                {"texts": texts[start:start + step], "input_type": "search_document"}
                for start in range(0, len(texts), step)
            ]

        # Reached only when SUPPORTED_MODELS lists a model that none of the
        # branches above knows how to build a request for.
        raise NotImplementedError(
            f"Model {self.model_id} is listed in SUPPORTED_MODELS but no request format "
            "is implemented for it. Add its request body based on the Amazon Bedrock "
            "documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html"
        )

    def _parse_embeddings(self, response_body: dict) -> List[List[float]]:
        '''Return the list of embedding vectors contained in one response.'''
        if self.model_id.startswith("cohere.embed"):
            return response_body['embeddings']
        if self.model_id.startswith("amazon.titan-embed-text"):
            # Titan returns a single vector; wrap it so the caller's extend()
            # adds one row instead of flattening the vector into scalars.
            return [response_body['embedding']]
        raise NotImplementedError(
            f"Not implemented yet! Check the format of response for model {self.model_id} "
            "from the Amazon Bedrock documentation: "
            "https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html"
        )

    def _extract_text_from_examples(self, inp_examples: List) -> List[str]:
        '''
        Turn the input into a list of strings.

        Plain strings are returned unchanged; otherwise each example's input
        fields (``example._input_keys``) are joined with single spaces.
        '''
        # Guard against empty input so the type probe below cannot raise IndexError.
        if len(inp_examples) == 0:
            return []
        if isinstance(inp_examples[0], str):
            return inp_examples
        return [" ".join([example[key] for key in example._input_keys]) for example in inp_examples]
Loading