Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 0.5.1 #135

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ libclang = "*"
clang = "*"
langchain = "*"
langchain-openai = "*"
langchain-community = "*"

[dev-packages]
pylint = "*"
Expand Down
2 changes: 1 addition & 1 deletion esbmc_ai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def main() -> None:
"-m",
"--ai-model",
default="",
help="Which AI model to use. Built-in models: {"
help="Which AI model to use. Built-in models: {OpenAI GPT models, "
+ ", ".join(_ai_model_names)
+ ", +custom models}",
)
Expand Down
95 changes: 79 additions & 16 deletions esbmc_ai/ai_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Author: Yiannis Charalambous

from abc import abstractmethod
from typing import Any, Iterable, Union
from typing import Any, Iterable, Optional, Union
from enum import Enum
from pydantic.v1.types import SecretStr
from typing_extensions import override
Expand Down Expand Up @@ -152,6 +152,7 @@ def create_llm(
requests_max_tries: int = 5,
requests_timeout: float = 60,
) -> BaseLanguageModel:
assert api_keys.openai, "No OpenAI api key has been specified..."
return ChatOpenAI(
model=self.name,
api_key=SecretStr(api_keys.openai),
Expand Down Expand Up @@ -267,11 +268,10 @@ def apply_chat_template(
)


class AIModels(Enum):
GPT_3 = AIModelOpenAI(name="gpt-3.5-turbo", tokens=4096)
GPT_3_16K = AIModelOpenAI(name="gpt-3.5-turbo-16k", tokens=16384)
GPT_4 = AIModelOpenAI(name="gpt-4", tokens=8192)
GPT_4_32K = AIModelOpenAI(name="gpt-4-32k", tokens=32768)
class _AIModels(Enum):
"""Private enum that contains predefined AI Models. OpenAI models are not
defined because they are fetched from the API."""

FALCON_7B = AIModelTextGen(
name="falcon-7b",
tokens=8192,
Expand All @@ -295,7 +295,7 @@ class AIModels(Enum):

_custom_ai_models: list[AIModel] = []

_ai_model_names: set[str] = set(item.value.name for item in AIModels)
_ai_model_names: set[str] = set(item.value.name for item in _AIModels)


def add_custom_ai_model(ai_model: AIModel) -> None:
Expand All @@ -307,19 +307,82 @@ def add_custom_ai_model(ai_model: AIModel) -> None:
_custom_ai_models.append(ai_model)


def is_valid_ai_model(ai_model: Union[str, AIModel]) -> bool:
"""Accepts both the AIModel object and the name as parameter."""
name: str
if isinstance(ai_model, AIModel):
name = ai_model.name
else:
name = ai_model
def is_valid_ai_model(
    ai_model: Union[str, AIModel], api_keys: Optional[APIKeyCollection] = None
) -> bool:
    """Accepts both the AIModel object and the name as parameter. It checks the
    openai servers to see if a model is defined on their servers, if not, then
    it checks the internally defined AI models list."""

    # Resolve the model name regardless of which form the caller passed.
    if isinstance(ai_model, AIModel):
        name: str = ai_model.name
    else:
        name = ai_model

    # Ask the OpenAI API first, when credentials are available.
    # NOTE: This is not tested as no way to mock API currently.
    if api_keys and api_keys.openai:
        try:
            from openai import Client

            remote_ids = (
                model.id
                for model in Client(api_key=api_keys.openai).models.list().data
            )
            if any(model_id == name for model_id in remote_ids):
                return True
        except ImportError:
            pass

    # Fall back to the predefined list of models.
    return name in _ai_model_names


def get_ai_model_by_name(name: str) -> AIModel:
def _get_openai_model_max_tokens(name: str) -> int:
"""NOTE: OpenAI currently does not expose an API for getting the model
length. Maybe add a config input value for this?"""

# https://platform.openai.com/docs/models
tokens = {
"gpt-4o": 128000,
"gpt-4": 8192,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-instruct": 4096,
}

# Split into - segments and remove each section from the end to find out
# which one matches the most.

# Base Case
if name in tokens:
return tokens[name]

# Step Case
name_split: list[str] = name.split("-")
for i in range(1, name.count("-")):
subname: str = "-".join(name_split[:-i])
if subname in tokens:
return tokens[subname]

raise ValueError(f"Could not figure out max tokens for model: {name}")


def get_ai_model_by_name(
name: str, api_keys: Optional[APIKeyCollection] = None
) -> AIModel:
# Check OpenAI models.
if api_keys and api_keys.openai:
try:
from openai import Client

for model in Client(api_key=api_keys.openai).models.list().data:
if model.id == name:
add_custom_ai_model(
AIModelOpenAI(
model.id,
_get_openai_model_max_tokens(model.id),
),
)
except ImportError:
pass

# Check AIModels enum.
for enum_value in AIModels:
for enum_value in _AIModels:
ai_model: AIModel = enum_value.value
if name == ai_model.name:
return ai_model
Expand Down
6 changes: 3 additions & 3 deletions esbmc_ai/api_key_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

"""API Key Collection definition."""

from typing import NamedTuple
from typing import NamedTuple, Optional


class APIKeyCollection(NamedTuple):
"""Class that is used to pass keys to AIModels."""

openai: str = ""
huggingface: str = ""
openai: Optional[str]
huggingface: Optional[str]
11 changes: 5 additions & 6 deletions esbmc_ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

temp_auto_clean: bool = True
temp_file_dir: str = "."
ai_model: AIModel = AIModels.GPT_3.value
ai_model: AIModel

esbmc_output_type: str = "full"
source_code_format: str = "full"
Expand Down Expand Up @@ -455,11 +455,10 @@ def load_config(file_path: str) -> None:
ai_model_name, _ = _load_config_value(
config_file,
"ai_model",
ai_model,
)
if is_valid_ai_model(ai_model_name):
if is_valid_ai_model(ai_model_name, api_keys):
# Load the ai_model from loaded models.
ai_model = get_ai_model_by_name(ai_model_name)
ai_model = get_ai_model_by_name(ai_model_name, api_keys)
else:
print(f"Error: {ai_model_name} is not a valid AI model")
sys.exit(4)
Expand All @@ -484,8 +483,8 @@ def load_args(args) -> None:

global ai_model
if args.ai_model != "":
if is_valid_ai_model(args.ai_model):
ai_model = get_ai_model_by_name(args.ai_model)
if is_valid_ai_model(args.ai_model, api_keys):
ai_model = get_ai_model_by_name(args.ai_model, api_keys)
else:
print(f"Error: invalid --ai-model parameter {args.ai_model}")
sys.exit(4)
Expand Down
Loading
Loading