Skip to content

Commit

Permalink
Merge pull request #134 from Yiannis128/133-allow-arbitrary-openai-llms
Browse files Browse the repository at this point in the history
Added support for arbitrary OpenAI models.
  • Loading branch information
Yiannis128 committed Jun 14, 2024
2 parents 0eea2c2 + cd6af05 commit d7afe05
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 182 deletions.
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ libclang = "*"
clang = "*"
langchain = "*"
langchain-openai = "*"
langchain-community = "*"

[dev-packages]
pylint = "*"
Expand Down
2 changes: 1 addition & 1 deletion esbmc_ai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def main() -> None:
"-m",
"--ai-model",
default="",
help="Which AI model to use. Built-in models: {"
help="Which AI model to use. Built-in models: {OpenAI GPT models, "
+ ", ".join(_ai_model_names)
+ ", +custom models}",
)
Expand Down
95 changes: 79 additions & 16 deletions esbmc_ai/ai_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Author: Yiannis Charalambous

from abc import abstractmethod
from typing import Any, Iterable, Union
from typing import Any, Iterable, Optional, Union
from enum import Enum
from pydantic.v1.types import SecretStr
from typing_extensions import override
Expand Down Expand Up @@ -152,6 +152,7 @@ def create_llm(
requests_max_tries: int = 5,
requests_timeout: float = 60,
) -> BaseLanguageModel:
assert api_keys.openai, "No OpenAI api key has been specified..."
return ChatOpenAI(
model=self.name,
api_key=SecretStr(api_keys.openai),
Expand Down Expand Up @@ -267,11 +268,10 @@ def apply_chat_template(
)


class AIModels(Enum):
GPT_3 = AIModelOpenAI(name="gpt-3.5-turbo", tokens=4096)
GPT_3_16K = AIModelOpenAI(name="gpt-3.5-turbo-16k", tokens=16384)
GPT_4 = AIModelOpenAI(name="gpt-4", tokens=8192)
GPT_4_32K = AIModelOpenAI(name="gpt-4-32k", tokens=32768)
class _AIModels(Enum):
"""Private enum that contains predefined AI Models. OpenAI models are not
defined because they are fetched from the API."""

FALCON_7B = AIModelTextGen(
name="falcon-7b",
tokens=8192,
Expand All @@ -295,7 +295,7 @@ class AIModels(Enum):

_custom_ai_models: list[AIModel] = []

_ai_model_names: set[str] = set(item.value.name for item in AIModels)
_ai_model_names: set[str] = set(item.value.name for item in _AIModels)


def add_custom_ai_model(ai_model: AIModel) -> None:
Expand All @@ -307,19 +307,82 @@ def add_custom_ai_model(ai_model: AIModel) -> None:
_custom_ai_models.append(ai_model)


def is_valid_ai_model(ai_model: Union[str, AIModel]) -> bool:
"""Accepts both the AIModel object and the name as parameter."""
name: str
if isinstance(ai_model, AIModel):
name = ai_model.name
else:
name = ai_model
def is_valid_ai_model(
    ai_model: Union[str, AIModel], api_keys: Optional[APIKeyCollection] = None
) -> bool:
    """Checks whether the given model (an AIModel instance or a plain model
    name) is usable. The OpenAI model listing is consulted first when an API
    key is available; otherwise the internally defined AI models list is
    used."""

    # Resolve the parameter down to a plain model name.
    if isinstance(ai_model, AIModel):
        model_name: str = ai_model.name
    else:
        model_name = ai_model

    # Query the OpenAI endpoint for a matching model id, when possible.
    # NOTE: This is not tested as no way to mock API currently.
    if api_keys and api_keys.openai:
        try:
            from openai import Client

            client = Client(api_key=api_keys.openai)
            if any(remote.id == model_name for remote in client.models.list().data):
                return True
        except ImportError:
            pass

    # Fall back to the internally registered model names.
    return model_name in _ai_model_names


def get_ai_model_by_name(name: str) -> AIModel:
def _get_openai_model_max_tokens(name: str) -> int:
"""NOTE: OpenAI currently does not expose an API for getting the model
length. Maybe add a config input value for this?"""

# https://platform.openai.com/docs/models
tokens = {
"gpt-4o": 128000,
"gpt-4": 8192,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-instruct": 4096,
}

# Split into - segments and remove each section from the end to find out
# which one matches the most.

# Base Case
if name in tokens:
return tokens[name]

# Step Case
name_split: list[str] = name.split("-")
for i in range(1, name.count("-")):
subname: str = "-".join(name_split[:-i])
if subname in tokens:
return tokens[subname]

raise ValueError(f"Could not figure out max tokens for model: {name}")


def get_ai_model_by_name(
name: str, api_keys: Optional[APIKeyCollection] = None
) -> AIModel:
# Check OpenAI models.
if api_keys and api_keys.openai:
try:
from openai import Client

for model in Client(api_key=api_keys.openai).models.list().data:
if model.id == name:
add_custom_ai_model(
AIModelOpenAI(
model.id,
_get_openai_model_max_tokens(model.id),
),
)
except ImportError:
pass

# Check AIModels enum.
for enum_value in AIModels:
for enum_value in _AIModels:
ai_model: AIModel = enum_value.value
if name == ai_model.name:
return ai_model
Expand Down
6 changes: 3 additions & 3 deletions esbmc_ai/api_key_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

"""API Key Collection definition."""

from typing import NamedTuple
from typing import NamedTuple, Optional


class APIKeyCollection(NamedTuple):
    """Class that is used to pass keys to AIModels."""

    # Keys are optional: a value of None simply disables the corresponding
    # service. Defaulting to None keeps no-argument construction (and
    # single-key construction) backward compatible with the previous
    # definition, which provided defaults for both fields.
    openai: Optional[str] = None
    huggingface: Optional[str] = None
11 changes: 5 additions & 6 deletions esbmc_ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

temp_auto_clean: bool = True
temp_file_dir: str = "."
ai_model: AIModel = AIModels.GPT_3.value
ai_model: AIModel

esbmc_output_type: str = "full"
source_code_format: str = "full"
Expand Down Expand Up @@ -455,11 +455,10 @@ def load_config(file_path: str) -> None:
ai_model_name, _ = _load_config_value(
config_file,
"ai_model",
ai_model,
)
if is_valid_ai_model(ai_model_name):
if is_valid_ai_model(ai_model_name, api_keys):
# Load the ai_model from loaded models.
ai_model = get_ai_model_by_name(ai_model_name)
ai_model = get_ai_model_by_name(ai_model_name, api_keys)
else:
print(f"Error: {ai_model_name} is not a valid AI model")
sys.exit(4)
Expand All @@ -484,8 +483,8 @@ def load_args(args) -> None:

global ai_model
if args.ai_model != "":
if is_valid_ai_model(args.ai_model):
ai_model = get_ai_model_by_name(args.ai_model)
if is_valid_ai_model(args.ai_model, api_keys):
ai_model = get_ai_model_by_name(args.ai_model, api_keys)
else:
print(f"Error: invalid --ai-model parameter {args.ai_model}")
sys.exit(4)
Expand Down
Loading

0 comments on commit d7afe05

Please sign in to comment.