From fa08eb412e5c003d58fb63ca0c9defebd8f9bae8 Mon Sep 17 00:00:00 2001 From: Jonah Sussman <42743659+JonahSussman@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:17:15 -0400 Subject: [PATCH] :bug: Fixed Amazon Bedrock dependency (#361) * Fixed aws dependency Signed-off-by: JonahSussman * Fixed class name Signed-off-by: JonahSussman --------- Signed-off-by: JonahSussman --- kai/service/llm_interfacing/model_provider.py | 5 +++-- pyproject.toml | 1 + requirements.txt | 20 ++++++++++++------- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/kai/service/llm_interfacing/model_provider.py b/kai/service/llm_interfacing/model_provider.py index ad6cde46..7f67bac3 100644 --- a/kai/service/llm_interfacing/model_provider.py +++ b/kai/service/llm_interfacing/model_provider.py @@ -3,7 +3,8 @@ from genai import Client, Credentials from genai.extensions.langchain.chat_llm import LangChainChatInterface from genai.schema import DecodingMethod -from langchain_community.chat_models import BedrockChat, ChatOllama, ChatOpenAI +from langchain_aws import ChatBedrock +from langchain_community.chat_models import ChatOllama, ChatOpenAI from langchain_community.chat_models.fake import FakeListChatModel from langchain_core.language_models.chat_models import BaseChatModel from langchain_google_genai import ChatGoogleGenerativeAI @@ -89,7 +90,7 @@ def __init__(self, config: KaiConfigModels): model_id = model_args["model_id"] case "ChatBedrock": - model_class = BedrockChat + model_class = ChatBedrock defaults = { "model_id": "meta.llama3-70b-instruct-v1:0", diff --git a/pyproject.toml b/pyproject.toml index ce68c9a9..91d130e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "langchain-community==0.2.10", "langchain-openai==0.1.23", "langchain-google-genai==1.0.9", + "langchain-aws==0.1.18", "langchain-experimental==0.0.64", "gunicorn==22.0.0", "tree-sitter==0.22.3", diff --git a/requirements.txt b/requirements.txt index a8be58e1..4cb819fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,7 +50,9 @@ beautifulsoup4==4.12.3 bleach==6.1.0 # via nbconvert boto3==1.34.157 - # via kai (pyproject.toml) + # via + # kai (pyproject.toml) + # langchain-aws botocore==1.34.162 # via # boto3 @@ -266,6 +268,8 @@ langchain==0.2.11 # via # kai (pyproject.toml) # langchain-community +langchain-aws==0.1.18 + # via kai (pyproject.toml) langchain-community==0.2.10 # via # kai (pyproject.toml) @@ -273,6 +277,7 @@ langchain-community==0.2.10 langchain-core==0.2.39 # via # langchain + # langchain-aws # langchain-community # langchain-experimental # langchain-google-genai @@ -286,7 +291,7 @@ langchain-openai==0.1.23 # via kai (pyproject.toml) langchain-text-splitters==0.2.4 # via langchain -langsmith==0.1.117 +langsmith==0.1.119 # via # langchain # langchain-community @@ -335,8 +340,9 @@ notebook-shim==0.2.4 numpy==1.26.4 # via # langchain + # langchain-aws # langchain-community -openai==1.44.1 +openai==1.45.0 # via langchain-openai orjson==3.10.7 # via langsmith @@ -392,11 +398,11 @@ ptyprocess==0.7.0 # terminado pure-eval==0.2.3 # via stack-data -pyasn1==0.6.0 +pyasn1==0.6.1 # via # pyasn1-modules # rsa -pyasn1-modules==0.4.0 +pyasn1-modules==0.4.1 # via google-auth pycparser==2.22 # via cffi @@ -460,7 +466,7 @@ referencing==0.35.1 # jsonschema # jsonschema-specifications # jupyter-events -regex==2024.7.24 +regex==2024.9.11 # via tiktoken requests==2.32.3 # via @@ -584,7 +590,7 @@ uri-template==1.3.0 # via jsonschema uritemplate==4.1.1 # via google-api-python-client -urllib3==2.2.2 +urllib3==2.2.3 # via # botocore # requests