diff --git a/Pipfile b/Pipfile index 3d0e65e..c739635 100644 --- a/Pipfile +++ b/Pipfile @@ -24,6 +24,7 @@ libclang = "*" clang = "*" langchain = "*" langchain-openai = "*" +langchain-community = "*" [dev-packages] pylint = "*" diff --git a/esbmc_ai/__main__.py b/esbmc_ai/__main__.py index f4ce4a9..8c528d3 100755 --- a/esbmc_ai/__main__.py +++ b/esbmc_ai/__main__.py @@ -218,7 +218,7 @@ def main() -> None: "-m", "--ai-model", default="", - help="Which AI model to use. Built-in models: {" + help="Which AI model to use. Built-in models: {OpenAI GPT models, " + ", ".join(_ai_model_names) + ", +custom models}", ) diff --git a/esbmc_ai/ai_models.py b/esbmc_ai/ai_models.py index 1874bde..16120f7 100644 --- a/esbmc_ai/ai_models.py +++ b/esbmc_ai/ai_models.py @@ -1,7 +1,7 @@ # Author: Yiannis Charalambous from abc import abstractmethod -from typing import Any, Iterable, Union +from typing import Any, Iterable, Optional, Union from enum import Enum from pydantic.v1.types import SecretStr from typing_extensions import override @@ -152,6 +152,7 @@ def create_llm( requests_max_tries: int = 5, requests_timeout: float = 60, ) -> BaseLanguageModel: + assert api_keys.openai, "No OpenAI api key has been specified..." return ChatOpenAI( model=self.name, api_key=SecretStr(api_keys.openai), @@ -267,11 +268,10 @@ def apply_chat_template( ) -class AIModels(Enum): - GPT_3 = AIModelOpenAI(name="gpt-3.5-turbo", tokens=4096) - GPT_3_16K = AIModelOpenAI(name="gpt-3.5-turbo-16k", tokens=16384) - GPT_4 = AIModelOpenAI(name="gpt-4", tokens=8192) - GPT_4_32K = AIModelOpenAI(name="gpt-4-32k", tokens=32768) +class _AIModels(Enum): + """Private enum that contains predefined AI Models. OpenAI models are not + defined because they are fetched from the API.""" + FALCON_7B = AIModelTextGen( name="falcon-7b", tokens=8192, @@ -295,7 +295,7 @@ class AIModels(Enum): _custom_ai_models: list[AIModel] = [] -_ai_model_names: set[str] = set(item.value.name for item in AIModels) +_ai_model_names: set[str] = set(item.value.name for item in _AIModels) def add_custom_ai_model(ai_model: AIModel) -> None: @@ -307,19 +307,82 @@ def add_custom_ai_model(ai_model: AIModel) -> None: _custom_ai_models.append(ai_model) -def is_valid_ai_model(ai_model: Union[str, AIModel]) -> bool: - """Accepts both the AIModel object and the name as parameter.""" - name: str - if isinstance(ai_model, AIModel): - name = ai_model.name - else: - name = ai_model +def is_valid_ai_model( + ai_model: Union[str, AIModel], api_keys: Optional[APIKeyCollection] = None +) -> bool: + """Accepts both the AIModel object and the name as parameter. It checks the + openai servers to see if a model is defined on their servers, if not, then + it checks the internally defined AI models list.""" + + # Get the name of the model + name: str = ai_model.name if isinstance(ai_model, AIModel) else ai_model + + # Try accessing openai api and checking if there is a model defined. + # NOTE: This is not tested as no way to mock API currently. + if api_keys and api_keys.openai: + try: + from openai import Client + + for model in Client(api_key=api_keys.openai).models.list().data: + if model.id == name: + return True + except ImportError: + pass + + # Use the predefined list of models. return name in _ai_model_names -def get_ai_model_by_name(name: str) -> AIModel: +def _get_openai_model_max_tokens(name: str) -> int: + """NOTE: OpenAI currently does not expose an API for getting the model + length. Maybe add a config input value for this?""" + + # https://platform.openai.com/docs/models + tokens = { + "gpt-4o": 128000, + "gpt-4": 8192, + "gpt-3.5-turbo": 16385, + "gpt-3.5-turbo-instruct": 4096, + } + + # Split into - segments and remove each section from the end to find out + # which one matches the most. + + # Base Case + if name in tokens: + return tokens[name] + + # Step Case + name_split: list[str] = name.split("-") + for i in range(1, name.count("-")): + subname: str = "-".join(name_split[:-i]) + if subname in tokens: + return tokens[subname] + + raise ValueError(f"Could not figure out max tokens for model: {name}") + + +def get_ai_model_by_name( + name: str, api_keys: Optional[APIKeyCollection] = None +) -> AIModel: + # Check OpenAI models. + if api_keys and api_keys.openai: + try: + from openai import Client + + for model in Client(api_key=api_keys.openai).models.list().data: + if model.id == name: + add_custom_ai_model( + AIModelOpenAI( + model.id, + _get_openai_model_max_tokens(model.id), + ), + ) + except ImportError: + pass + # Check AIModels enum. - for enum_value in AIModels: + for enum_value in _AIModels: ai_model: AIModel = enum_value.value if name == ai_model.name: return ai_model diff --git a/esbmc_ai/api_key_collection.py b/esbmc_ai/api_key_collection.py index eba635d..9aeec6f 100644 --- a/esbmc_ai/api_key_collection.py +++ b/esbmc_ai/api_key_collection.py @@ -2,11 +2,11 @@ """API Key Collection definition.""" -from typing import NamedTuple +from typing import NamedTuple, Optional class APIKeyCollection(NamedTuple): """Class that is used to pass keys to AIModels.""" - openai: str = "" - huggingface: str = "" + openai: Optional[str] + huggingface: Optional[str] diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py index 3ce23fa..44007b3 100644 --- a/esbmc_ai/config.py +++ b/esbmc_ai/config.py @@ -38,7 +38,7 @@ temp_auto_clean: bool = True temp_file_dir: str = "." -ai_model: AIModel = AIModels.GPT_3.value +ai_model: AIModel esbmc_output_type: str = "full" source_code_format: str = "full" @@ -455,11 +455,10 @@ def load_config(file_path: str) -> None: ai_model_name, _ = _load_config_value( config_file, "ai_model", - ai_model, ) - if is_valid_ai_model(ai_model_name): + if is_valid_ai_model(ai_model_name, api_keys): # Load the ai_model from loaded models. - ai_model = get_ai_model_by_name(ai_model_name) + ai_model = get_ai_model_by_name(ai_model_name, api_keys) else: print(f"Error: {ai_model_name} is not a valid AI model") sys.exit(4) @@ -484,8 +483,8 @@ def load_args(args) -> None: global ai_model if args.ai_model != "": - if is_valid_ai_model(args.ai_model): - ai_model = get_ai_model_by_name(args.ai_model) + if is_valid_ai_model(args.ai_model, api_keys): + ai_model = get_ai_model_by_name(args.ai_model, api_keys) else: print(f"Error: invalid --ai-model parameter {args.ai_model}") sys.exit(4) diff --git a/requirements.txt b/requirements.txt index 41a599e..583ea95 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,153 +1,49 @@ -i https://pypi.org/simple -accelerate==0.20.3 -aiohttp==3.8.4 -aiosignal==1.3.1 -anthropic==0.2.10 -anyio==3.7.0 ; python_version >= '3.7' -appdirs==1.4.4 -asgiref==3.7.2 ; python_version >= '3.7' -async-timeout==4.0.2 -attrs==23.1.0 -bentoml[grpc,io]==1.0.22 ; python_version >= '3.7' -build==0.10.0 ; python_version >= '3.7' -cattrs==23.1.2 ; python_version >= '3.7' -certifi==2022.12.7 -charset-normalizer==3.1.0 -circus==0.18.0 ; python_version >= '3.7' -clang==16.0.1.1 -clarifai==9.1.0 -clarifai-grpc==9.5.0 ; python_version >= '3.8' -click==8.1.3 ; python_version >= '3.7' -click-option-group==0.5.6 ; python_version >= '3.6' and python_version < '4' -cloudpickle==2.2.1 ; python_version >= '3.6' -cmake==3.26.4 -cohere==3.10.0 -coloredlogs==15.0.1 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' -contextlib2==21.6.0 ; python_version >= '3.6' -dataclasses-json==0.5.8 ; python_version >= '3.6' -datasets==2.13.0 ; python_full_version >= '3.7.0' -deepmerge==1.1.0 -deprecated==1.2.14 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' -dill==0.3.6 ; python_version >= '3.7' -filelock==3.12.2 ; python_version >= '3.7' -filetype==1.2.0 -frozenlist==1.3.3 -fs==2.4.16 -fsspec==2023.6.0 ; python_version >= '3.8' -googleapis-common-protos==1.59.1 ; python_version >= '3.7' -greenlet==2.0.2 ; platform_machine == 'aarch64' or (platform_machine == 'ppc64le' or (platform_machine == 'x86_64' or (platform_machine == 'amd64' or (platform_machine == 'AMD64' or (platform_machine == 'win32' or platform_machine == 'WIN32'))))) -grpcio==1.54.2 -grpcio-health-checking==1.48.2 -h11==0.14.0 ; python_version >= '3.7' -httpcore==0.17.2 ; python_version >= '3.7' -httpx==0.24.1 ; python_version >= '3.7' -huggingface-hub==0.15.1 ; python_full_version >= '3.7.0' -humanfriendly==10.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' -idna==3.4 -importlib-metadata==6.0.1 ; python_version >= '3.7' -inflection==0.5.1 ; python_version >= '3.5' -jinja2==3.1.2 ; python_version >= '3.7' -langchain[llms]==0.0.209 -langchainplus-sdk==0.0.16 ; python_version < '4.0' and python_full_version >= '3.8.1' -libclang==16.0.0 -lit==16.0.6 -manifest-ml==0.0.1 -markdown-it-py==3.0.0 ; python_version >= '3.8' -markupsafe==2.1.3 ; python_version >= '3.7' -marshmallow==3.19.0 ; python_version >= '3.7' -marshmallow-enum==1.5.1 -mdurl==0.1.2 ; python_version >= '3.7' -mpmath==1.3.0 -multidict==6.0.4 -multiprocess==0.70.14 ; python_version >= '3.7' -mypy-extensions==1.0.0 ; python_version >= '3.5' -networkx==3.1 ; python_version >= '3.8' -nlpcloud==1.0.42 -numexpr==2.8.4 ; python_version >= '3.7' -numpy==1.25.0 ; python_version >= '3.9' -nvidia-cublas-cu11==11.10.3.66 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-cuda-cupti-cu11==11.7.101 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-cuda-nvrtc-cu11==11.7.99 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-cuda-runtime-cu11==11.7.99 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-cudnn-cu11==8.5.0.96 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-cufft-cu11==10.9.0.58 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-curand-cu11==10.2.10.91 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-cusolver-cu11==11.4.0.1 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-cusparse-cu11==11.7.4.91 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-nccl-cu11==2.14.3 ; platform_system == 'Linux' and platform_machine == 'x86_64' -nvidia-nvtx-cu11==11.7.91 ; platform_system == 'Linux' and platform_machine == 'x86_64' -openai==0.27.5 -openapi-schema-pydantic==1.2.4 ; python_full_version >= '3.6.1' -openllm==0.1.10 -openlm==0.0.5 -opentelemetry-api==1.17.0 ; python_version >= '3.7' -opentelemetry-instrumentation==0.38b0 ; python_version >= '3.7' -opentelemetry-instrumentation-aiohttp-client==0.38b0 ; python_version >= '3.7' -opentelemetry-instrumentation-asgi==0.38b0 ; python_version >= '3.7' -opentelemetry-instrumentation-grpc==0.38b0 -opentelemetry-sdk==1.17.0 ; python_version >= '3.7' -opentelemetry-semantic-conventions==0.38b0 ; python_version >= '3.7' -opentelemetry-util-http==0.38b0 ; python_version >= '3.7' -optimum==1.8.8 ; python_full_version >= '3.7.0' -orjson==3.9.1 ; python_version >= '3.7' -packaging==23.1 ; python_version >= '3.7' -pandas==2.0.2 -pathspec==0.11.1 ; python_version >= '3.7' -pillow==9.5.0 -pip==23.1.2 ; python_version >= '3.7' -pip-requirements-parser==32.0.1 ; python_full_version >= '3.6.0' -pip-tools==6.13.0 ; python_version >= '3.7' -prometheus-client==0.17.0 ; python_version >= '3.6' -protobuf==3.20.3 -psutil==5.9.5 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' -pyarrow==12.0.1 -pydantic==1.10.9 ; python_version >= '3.7' -pygments==2.15.1 ; python_version >= '3.7' -pynvml==11.5.0 ; python_version >= '3.6' -pyparsing==3.1.0 ; python_full_version >= '3.6.8' -pyproject-hooks==1.0.0 ; python_version >= '3.7' -python-dateutil==2.8.2 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' -python-dotenv==1.0.0 -python-json-logger==2.0.7 ; python_version >= '3.6' -python-multipart==0.0.6 ; python_version >= '3.7' -pytz==2023.3 -pyyaml==6.0 ; python_version >= '3.6' -pyzmq==25.1.0 ; python_version >= '3.6' -redis==4.5.5 ; python_version >= '3.7' -regex==2023.3.23 -requests==2.29.0 -rich==13.4.2 ; python_full_version >= '3.7.0' -safetensors==0.3.1 -schema==0.7.5 -sentencepiece==0.1.99 -setuptools==68.0.0 ; python_version >= '3.7' -simple-di==0.1.5 ; python_full_version >= '3.6.1' -six==1.16.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3' -sniffio==1.3.0 ; python_version >= '3.7' -sqlalchemy==2.0.16 ; python_version >= '3.7' -sqlitedict==2.1.0 -starlette==0.28.0 ; python_version >= '3.7' -sympy==1.12 ; python_version >= '3.8' -tabulate[widechars]==0.9.0 ; python_version >= '3.7' -tenacity==8.2.2 ; python_version >= '3.6' -text-generation==0.6.0 -tiktoken==0.3.3 -tokenizers==0.13.3 -torch==2.0.1 -torchvision==0.15.2 ; python_version >= '3.8' -tornado==6.3.2 ; python_version >= '3.8' -tqdm==4.65.0 -transformers==4.30.2 -triton==2.0.0 ; platform_system == 'Linux' and platform_machine == 'x86_64' -typing-extensions==4.6.3 ; python_version >= '3.7' +aiohttp==3.8.4; python_version >= '3.6' +aiosignal==1.3.1; python_version >= '3.7' +annotated-types==0.6.0; python_version >= '3.8' +anyio==4.3.0; python_version >= '3.8' +async-timeout==4.0.2; python_version >= '3.6' +attrs==23.1.0; python_version >= '3.7' +certifi==2022.12.7; python_version >= '3.6' +charset-normalizer==3.1.0; python_full_version >= '3.7.0' +clang==17.0.6 +dataclasses-json==0.6.4; python_version >= '3.7' and python_version < '4.0' +distro==1.9.0; python_version >= '3.6' +frozenlist==1.3.3; python_version >= '3.7' +greenlet==3.0.3; platform_machine == 'aarch64' or (platform_machine == 'ppc64le' or (platform_machine == 'x86_64' or (platform_machine == 'amd64' or (platform_machine == 'AMD64' or (platform_machine == 'win32' or platform_machine == 'WIN32'))))) +h11==0.14.0; python_version >= '3.7' +httpcore==1.0.5; python_version >= '3.8' +httpx==0.27.0; python_version >= '3.8' +idna==3.4; python_version >= '3.5' +jsonpatch==1.33; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6' +jsonpointer==2.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6' +langchain==0.1.16; python_version < '4.0' and python_full_version >= '3.8.1' +langchain-community==0.0.34; python_version < '4.0' and python_full_version >= '3.8.1' +langchain-core==0.1.45; python_version < '4.0' and python_full_version >= '3.8.1' +langchain-openai==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1' +langchain-text-splitters==0.0.1; python_version < '4.0' and python_full_version >= '3.8.1' +langsmith==0.1.49; python_version < '4.0' and python_full_version >= '3.8.1' +libclang==18.1.1 +marshmallow==3.21.1; python_version >= '3.8' +multidict==6.0.4; python_version >= '3.7' +mypy-extensions==1.0.0; python_version >= '3.5' +numpy==1.26.4; python_version >= '3.9' +openai==1.23.2; python_full_version >= '3.7.1' +orjson==3.10.1; python_version >= '3.8' +packaging==23.2; python_version >= '3.7' +pydantic==2.7.0; python_version >= '3.8' +pydantic-core==2.18.1; python_version >= '3.8' +python-dotenv==1.0.0; python_version >= '3.8' +pyyaml==6.0.1; python_version >= '3.6' +regex==2023.3.23; python_version >= '3.8' +requests==2.29.0; python_version >= '3.7' +sniffio==1.3.1; python_version >= '3.7' +sqlalchemy==2.0.29; python_version >= '3.7' +tenacity==8.2.3; python_version >= '3.7' +tiktoken==0.6.0; python_version >= '3.8' +tqdm==4.66.2; python_version >= '3.7' +typing-extensions==4.11.0; python_version >= '3.8' typing-inspect==0.9.0 -tzdata==2023.3 ; python_version >= '2' -urllib3==1.26.15 -uvicorn==0.22.0 ; python_version >= '3.7' -watchfiles==0.19.0 ; python_version >= '3.7' -wcwidth==0.2.6 -wheel==0.40.0 ; python_version >= '3.7' -wrapt==1.15.0 ; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4' -xxhash==3.2.0 ; python_version >= '3.6' -yarl==1.9.2 -zipp==3.15.0 ; python_version >= '3.7' +urllib3==1.26.15; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' +yarl==1.9.2; python_version >= '3.7' diff --git a/tests/test_ai_models.py b/tests/test_ai_models.py index e9bdba6..028bff3 100644 --- a/tests/test_ai_models.py +++ b/tests/test_ai_models.py @@ -13,16 +13,18 @@ add_custom_ai_model, is_valid_ai_model, AIModel, - AIModels, + _AIModels, get_ai_model_by_name, AIModelTextGen, + _get_openai_model_max_tokens, ) +"""TODO Find a way to mock the OpenAI API and test GPT LLM code.""" + def test_is_valid_ai_model() -> None: - assert is_valid_ai_model(AIModels.FALCON_7B.value) - assert is_valid_ai_model(AIModels.GPT_3_16K.value) - assert is_valid_ai_model("gpt-3.5-turbo") + assert is_valid_ai_model(_AIModels.FALCON_7B.value) + assert is_valid_ai_model(_AIModels.STARCHAT_BETA.value) assert is_valid_ai_model("falcon-7b") @@ -56,7 +58,7 @@ def test_add_custom_ai_model() -> None: def test_get_ai_model_by_name() -> None: # Try with first class AI - assert get_ai_model_by_name("gpt-3.5-turbo") + assert get_ai_model_by_name("falcon-7b") # Try with custom AI. # Add custom AI model if not added by previous tests. @@ -141,3 +143,15 @@ def test_escape_messages() -> None: assert result[3] == filtered[3] assert result[4] == filtered[4] assert result[5] == filtered[5] + + +def test__get_openai_model_max_tokens() -> None: + assert _get_openai_model_max_tokens("gpt-4o") == 128000 + assert _get_openai_model_max_tokens("gpt-4-turbo") == 8192 + assert _get_openai_model_max_tokens("gpt-3.5-turbo") == 16385 + assert _get_openai_model_max_tokens("gpt-3.5-turbo-instruct") == 4096 + assert _get_openai_model_max_tokens("gpt-3.5-turbo-aaaaaa") == 16385 + assert _get_openai_model_max_tokens("gpt-3.5-turbo-instruct-bbb") == 4096 + + with raises(ValueError): + _get_openai_model_max_tokens("aaaaa")