Skip to content

Commit

Permalink
Merge pull request #345 from aws-samples/xuhan-aics
Browse files Browse the repository at this point in the history
feat: support connection to custom opensearch domain
  • Loading branch information
IcyKallen committed Aug 15, 2024
2 parents 21dcb88 + 1c36f6c commit a6e4979
Show file tree
Hide file tree
Showing 7 changed files with 593 additions and 249 deletions.
1 change: 1 addition & 0 deletions source/infrastructure/lib/chat/chat-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
"secretsmanager:GetSecretValue",
"bedrock:*",
"lambda:InvokeFunction",
"secretmanager:GetSecretValue",
],
effect: iam.Effect.ALLOW,
resources: ["*"],
Expand Down
2 changes: 1 addition & 1 deletion source/infrastructure/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"clobber": "npx projen clobber",
"compile": "npx projen compile",
"default": "npx projen default",
"deploy": "npx projen deploy",
"deploy": "npx cdk deploy",
"destroy": "npx projen destroy",
"diff": "npx projen diff",
"eject": "npx projen eject",
Expand Down
174 changes: 100 additions & 74 deletions source/lambda/online/functions/functions_utils/retriever/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,33 @@

os.environ["PYTHONUNBUFFERED"] = "1"
import logging

import boto3
import sys

from functions.functions_utils.retriever.utils.aos_retrievers import QueryDocumentKNNRetriever, QueryDocumentBM25Retriever, QueryQuestionRetriever
from functions.functions_utils.retriever.utils.reranker import BGEReranker, MergeReranker
from functions.functions_utils.retriever.utils.context_utils import retriever_results_format
from functions.functions_utils.retriever.utils.websearch_retrievers import GoogleRetriever

from langchain.retrievers import ContextualCompressionRetriever, AmazonKnowledgeBasesRetriever
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever

from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.schema.runnable import (
RunnableLambda,
RunnablePassthrough,
)
from common_logic.common_utils.lambda_invoke_utils import chatbot_lambda_call_wrapper
import boto3
from common_logic.common_utils.chatbot_utils import ChatbotManager
from common_logic.common_utils.lambda_invoke_utils import chatbot_lambda_call_wrapper
from functions.functions_utils.retriever.utils.aos_retrievers import (
QueryDocumentBM25Retriever,
QueryDocumentKNNRetriever,
QueryQuestionRetriever,
)
from functions.functions_utils.retriever.utils.context_utils import (
retriever_results_format,
)
from functions.functions_utils.retriever.utils.reranker import (
BGEReranker,
MergeReranker,
)
from functions.functions_utils.retriever.utils.websearch_retrievers import (
GoogleRetriever,
)
from langchain.retrievers import (
AmazonKnowledgeBasesRetriever,
ContextualCompressionRetriever,
)
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever

logger = logging.getLogger("retriever")
logger.setLevel(logging.INFO)
Expand All @@ -43,61 +51,66 @@
knowledgebase_client = boto3.client("bedrock-agent-runtime", region)
sm_client = boto3.client("sagemaker-runtime")

def get_bedrock_kb_retrievers(knowledge_base_id_list, top_k:int):

def get_bedrock_kb_retrievers(knowledge_base_id_list, top_k: int):
retriever_list = [
AmazonKnowledgeBasesRetriever(
knowledge_base_id=knowledge_base_id,
retrieval_config={"vectorSearchConfiguration": {"numberOfResults": top_k}})
retrieval_config={"vectorSearchConfiguration": {"numberOfResults": top_k}},
)
for knowledge_base_id in knowledge_base_id_list
]
return retriever_list

def get_websearch_retrievers(top_k:int):
retriever_list = [
GoogleRetriever(top_k)
]

def get_websearch_retrievers(top_k: int):
retriever_list = [GoogleRetriever(top_k)]
return retriever_list

def get_custom_qd_retrievers(config:dict,using_bm25=False):

def get_custom_qd_retrievers(config: dict, using_bm25=False):
qd_retriever = QueryDocumentKNNRetriever(**config)

if using_bm25:
bm25_retrievert = QueryDocumentBM25Retriever(
**{
"index_name": config['index_name'],
"using_whole_doc": config.get("using_whole_doc",False),
"context_num":config["context_num"],
"enable_debug": config.get('enable_debug',False)
}
)
**{
"index_name": config["index_name"],
"using_whole_doc": config.get("using_whole_doc", False),
"context_num": config["context_num"],
"enable_debug": config.get("enable_debug", False),
}
)
return [qd_retriever, bm25_retrievert]
return [qd_retriever]

def get_custom_qq_retrievers(config:dict):
qq_retriever = QueryQuestionRetriever(
model_type="vector",
**config
)

def get_custom_qq_retrievers(config: dict):
qq_retriever = QueryQuestionRetriever(model_type="vector", **config)
return [qq_retriever]


def get_whole_chain(retriever_list, reranker_config):
lotr = MergerRetriever(retrievers=retriever_list)
if len(reranker_config):
default_reranker_config = {
"enable_debug": False,
"target_model": "bge_reranker_model.tar.gz",
"top_k": 10
}
reranker_config = {**default_reranker_config, **reranker_config}
compressor = BGEReranker(**reranker_config)
else:
compressor = MergeReranker()
# if len(reranker_config):
# default_reranker_config = {
# "enable_debug": False,
# "target_model": "bge_reranker_model.tar.gz",
# "top_k": 10,
# }
# reranker_config = {**default_reranker_config, **reranker_config}
# compressor = BGEReranker(**reranker_config)
# else:
# compressor = MergeReranker()

# Disable Reranker for AICS Guidance
compressor = MergeReranker()

compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=lotr
)
whole_chain = RunnablePassthrough.assign(
docs=compression_retriever | RunnableLambda(retriever_results_format))
docs=compression_retriever | RunnableLambda(retriever_results_format)
)
return whole_chain


Expand All @@ -109,54 +122,67 @@ def get_whole_chain(retriever_list, reranker_config):
"bedrock_kb": get_bedrock_kb_retrievers,
}


def get_custom_retrievers(retriever):
return retriever_dict[retriever['index_type']](retriever)
return retriever_dict[retriever["index_type"]](retriever)


def lambda_handler(event, context=None):
logger.info(f"Retrieval event: {event}")
event_body = event
event_body["retrievers"].append(
{
"index_type": "qd",
"index_name": "test",
"vector_field": "sentence_vector",
"source_field": "source",
"text_field": "paragraph",
}
)
retriever_list = []
print(retriever_list)
for retriever in event_body["retrievers"]:
retriever_list.extend(get_custom_retrievers(retriever))
rerankers = event_body.get("rerankers", None)
if rerankers:
reranker_config = rerankers[0]["config"]
else:
reranker_config = {}

if len(retriever_list) > 0:
whole_chain = get_whole_chain(retriever_list, reranker_config)
else:
whole_chain = RunnablePassthrough.assign(docs = lambda x: [])
whole_chain = RunnablePassthrough.assign(docs=lambda x: [])
docs = whole_chain.invoke({"query": event_body["query"], "debug_info": {}})
return {"code":0, "result": docs}
return {"code": 0, "result": docs}


if __name__ == "__main__":
query = '''test'''
query = """test"""
event = {
"body":
json.dumps(
{
"retrievers": [
{
"type": "qq",
"index_ids": ["test"],
"config": {
"top_k": 10,
}
"body": json.dumps(
{
"retrievers": [
{
"type": "qq",
"index_ids": ["test"],
"config": {
"top_k": 10,
},
],
"rerankers": [
{
"type": "reranker",
"config": {
"enable_debug": False,
"target_model": "bge_reranker_model.tar.gz"
}
}
],
"query": query
}
)
},
],
"rerankers": [
{
"type": "reranker",
"config": {
"enable_debug": False,
"target_model": "bge_reranker_model.tar.gz",
},
}
],
"query": query,
}
)
}
response = lambda_handler(event, None)
print(response)
print(response)
Loading

0 comments on commit a6e4979

Please sign in to comment.