Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for google gemini model. #307

Merged
merged 4 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ services:
GENAI_KEY:
OPENAI_API_BASE:
OPENAI_API_KEY:
GOOGLE_API_KEY:
image: ${IMAGE:-quay.io/konveyor/kai}:${TAG:-stable}
volumes:
- ${PWD}:/podman_compose:rw,z
Expand Down
7 changes: 7 additions & 0 deletions kai/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,10 @@ model_id = "mistralai/mixtral-8x7b-instruct-v01"

[embeddings]
todo = true

# **Google Gemini Pro**
# [models]
# provider = "ChatGoogleGenerativeAI"

# [models.args]
# model = "gemini-pro"
13 changes: 13 additions & 0 deletions kai/service/llm_interfacing/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain_community.chat_models import BedrockChat, ChatOllama, ChatOpenAI
from langchain_community.chat_models.fake import FakeListChatModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic.v1.utils import deep_update

from kai.models.kai_config import KaiConfigModels
Expand Down Expand Up @@ -122,6 +123,18 @@ def __init__(self, config: KaiConfigModels):
model_args = deep_update(defaults, config.args)
model_id = "fake-list-chat-model"

case "ChatGoogleGenerativeAI":
model_class = ChatGoogleGenerativeAI
api_key = os.getenv("GOOGLE_API_KEY", "dummy_value")
defaults = {
"model": "gemini-pro",
"temperature": 0.7,
"streaming": False,
"google_api_key": api_key,
}
model_args = deep_update(defaults, config.args)
model_id = model_args["model"]

case _:
raise Exception(f"Unrecognized provider '{config.provider}'")

Expand Down
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ibm-generative-ai==2.2.0
Jinja2==3.1.4
langchain==0.2.11
langchain-community==0.2.10
langchain-google-genai==1.0.9
langchain-openai==0.1.17
langchain-experimental==0.0.64
gunicorn==22.0.0
Expand Down
76 changes: 72 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async-lru==2.0.4
# via jupyterlab
async-timeout==4.0.3
# via -r requirements.in
attrs==24.2.0
attrs==23.2.0
# via
# aiohttp
# jsonschema
Expand All @@ -55,6 +55,8 @@ botocore==1.34.162
# via
# boto3
# s3transfer
cachetools==5.5.0
# via google-auth
certifi==2024.7.4
# via
# httpcore
Expand Down Expand Up @@ -105,14 +107,48 @@ gitdb==4.0.11
# via gitpython
gitpython==3.1.43
# via -r requirements.in
google-ai-generativelanguage==0.6.6
# via google-generativeai
google-api-core[grpc]==2.19.1
# via
# google-ai-generativelanguage
# google-api-python-client
# google-generativeai
google-api-python-client==2.142.0
# via google-generativeai
google-auth==2.34.0
# via
# google-ai-generativelanguage
# google-api-core
# google-api-python-client
# google-auth-httplib2
# google-generativeai
google-auth-httplib2==0.2.0
# via google-api-python-client
google-generativeai==0.7.2
# via langchain-google-genai
googleapis-common-protos==1.63.2
# via
# google-api-core
# grpcio-status
greenlet==3.0.3
# via sqlalchemy
grpcio==1.65.5
# via
# google-api-core
# grpcio-status
grpcio-status==1.62.3
# via google-api-core
gunicorn==22.0.0
# via -r requirements.in
h11==0.14.0
# via httpcore
httpcore==1.0.5
# via httpx
httplib2==0.22.0
# via
# google-api-python-client
# google-auth-httplib2
httpx==0.26.0
# via
# ibm-generative-ai
Expand Down Expand Up @@ -234,15 +270,18 @@ langchain-community==0.2.10
# via
# -r requirements.in
# langchain-experimental
langchain-core==0.2.33
langchain-core==0.2.34
# via
# langchain
# langchain-community
# langchain-experimental
# langchain-google-genai
# langchain-openai
# langchain-text-splitters
langchain-experimental==0.0.64
# via -r requirements.in
langchain-google-genai==1.0.9
# via -r requirements.in
langchain-openai==0.1.17
# via -r requirements.in
langchain-text-splitters==0.2.2
Expand Down Expand Up @@ -331,6 +370,18 @@ prompt-toolkit==3.0.47
# via
# ipython
# jupyter-console
proto-plus==1.24.0
# via
# google-ai-generativelanguage
# google-api-core
protobuf==4.25.4
# via
# google-ai-generativelanguage
# google-api-core
# google-generativeai
# googleapis-common-protos
# grpcio-status
# proto-plus
psutil==6.0.0
# via ipykernel
psycopg2-binary==2.9.9
Expand All @@ -341,11 +392,18 @@ ptyprocess==0.7.0
# terminado
pure-eval==0.2.3
# via stack-data
pyasn1==0.6.0
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.4.0
# via google-auth
pycparser==2.22
# via cffi
pydantic==2.8.2
# via
# -r requirements.in
# google-generativeai
# ibm-generative-ai
# langchain
# langchain-core
Expand All @@ -364,6 +422,8 @@ pygments==2.18.0
# jupyter-console
# nbconvert
# qtconsole
pyparsing==3.1.2
# via httplib2
python-dateutil==2.8.2
# via
# -r requirements.in
Expand Down Expand Up @@ -406,6 +466,7 @@ regex==2024.7.24
requests==2.32.3
# via
# -r requirements.in
# google-api-core
# jupyterlab-server
# langchain
# langchain-community
Expand All @@ -423,6 +484,8 @@ rpds-py==0.20.0
# via
# jsonschema
# referencing
rsa==4.9
# via google-auth
s3transfer==0.10.2
# via boto3
send2trash==1.8.3
Expand Down Expand Up @@ -473,8 +536,10 @@ tornado==6.4.1
# jupyterlab
# notebook
# terminado
tqdm==4.66.5
# via openai
tqdm==4.66.4
# via
# google-generativeai
# openai
traitlets==5.14.3
# via
# comm
Expand Down Expand Up @@ -504,6 +569,7 @@ types-python-dateutil==2.9.0.20240316
# via arrow
typing-extensions==4.12.2
# via
# google-generativeai
# langchain-core
# openai
# pydantic
Expand All @@ -517,6 +583,8 @@ unidiff==0.7.5
# via -r requirements.in
uri-template==1.3.0
# via jsonschema
uritemplate==4.1.1
# via google-api-python-client
urllib3==2.2.2
# via
# botocore
Expand Down
Loading