
Feature/litellm generator #572

Merged (19 commits, Apr 10, 2024)

Commits
- 2fc5301 build: add litellm to project requirements (Tien-Cheng, Mar 24, 2024)
- dbdfbea feat(litellm.py): implement litellm based garak generator (Tien-Cheng, Mar 24, 2024)
- 9e7e85d fix: ensure cli validates that model name is provided for litellm gen… (Tien-Cheng, Mar 24, 2024)
- 530b5a6 fix(litellm.py): supress logs generated by litellm (Tien-Cheng, Mar 24, 2024)
- a1d7eb5 fix(litellm.py): fix custom provider in config not being provided to… (Tien-Cheng, Mar 24, 2024)
- 00ad740 fix(litellm.py): raise error if openai api key not supplied when usin… (Tien-Cheng, Mar 24, 2024)
- a17477f fix(litellm.py): fix bug where certain providers did not support mult… (Tien-Cheng, Mar 24, 2024)
- 3c5ef22 test(test_litellm.py): add unit tests for LiteLLMGenerator (Tien-Cheng, Mar 24, 2024)
- 85de0a2 fix(litellm.py): fix syntax error caused by typo (Tien-Cheng, Mar 24, 2024)
- ae4e5af docs(litellm.py): add docstring (Tien-Cheng, Mar 24, 2024)
- 1496362 Merge branch 'main' into feature/litellm-generator (leondz, Apr 5, 2024)
- bac5cee project toml syntax fix (leondz, Apr 10, 2024)
- 8f95f23 add to docs; reject if no provider set (leondz, Apr 10, 2024)
- c2111c8 factor literal out to variable (leondz, Apr 10, 2024)
- 5a711b4 factor literal out to variable (leondz, Apr 10, 2024)
- ce1a7ec sync signature w base class (leondz, Apr 10, 2024)
- 4fcfaea raise exception if litellm request but no config is set (leondz, Apr 10, 2024)
- 0a4c025 Merge branch 'feature/litellm-generator' of https://github.com/Tien-C… (leondz, Apr 10, 2024)
- 7e6fedb it's fine to run litellm without a config (in some cases) (leondz, Apr 10, 2024)
garak/cli.py (2 additions, 1 deletion)

@@ -419,7 +419,8 @@ def main(arguments=[]) -> None:
# model is specified, we're doing something
elif _config.plugins.model_type:
if (
_config.plugins.model_type in ("openai", "replicate", "ggml", "huggingface")
_config.plugins.model_type
in ("openai", "replicate", "ggml", "huggingface", "litellm")
and not _config.plugins.model_name
):
message = f"⚠️ Model type '{_config.plugins.model_type}' also needs a model name\n You can set one with e.g. --model_name \"billwurtz/gpt-1.0\""
garak/generators/__init__.py (4 additions, 1 deletion)

@@ -13,7 +13,10 @@
def load_generator(
model_name: str, model_type: str, generations: int = 10
) -> Generator:
if model_type in ("openai", "replicate", "ggml", "huggingface") and not model_name:
if (
model_type in ("openai", "replicate", "ggml", "huggingface", "litellm")
and not model_name
):
message = f"⚠️ Model type '{model_type}' also needs a model name"
logger.error(message)
raise ValueError(message)
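As a quick illustration of the tightened check above, here is a minimal sketch (not part of this PR) of how `load_generator` now rejects a litellm model type supplied without a model name; it assumes garak is importable as shown in the diff and uses pytest only for brevity:

```python
# Sketch only: exercises the validation added in garak/generators/__init__.py above.
# Assumes garak is installed; pytest.raises is used to show the expected ValueError.
import pytest

from garak.generators import load_generator

with pytest.raises(ValueError):
    # "litellm" is now among the model types that require a model name
    load_generator(model_name=None, model_type="litellm")
```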
garak/generators/litellm.py (174 additions, new file)

@@ -0,0 +1,174 @@
"""LiteLLM model support

Support for LiteLLM, which allows calling LLM APIs using the OpenAI format.

Depending on the provider inferred from the model name, LiteLLM automatically
reads API keys from the corresponding environment variables
(e.g. OPENAI_API_KEY for OpenAI models).

The API key can also be set directly in the supplied generator JSON config.
This also enables support for any custom provider that follows the OpenAI API format.

e.g. supply a JSON config like this for Ollama's OpenAI-compatible API:
```json
{
"litellm.LiteLLMGenerator" : {
"api_base" : "http://localhost:11434/v1",
"provider" : "openai",
"api_key" : "test"
}
}
```

The above is an example of a config to connect LiteLLM with Ollama's OpenAI compatible API.

Then, when invoking garak, we pass it the path to the generator option file.

```
python -m garak --model_type litellm --model_name "phi" --generator_option_file ollama_base.json -p dan
```
"""

import logging

from os import getenv
from typing import List, Union

import backoff

import litellm

from garak import _config
from garak.generators.base import Generator

# Fix issue with Ollama which does not support `presence_penalty`
litellm.drop_params = True
# Suppress log messages from LiteLLM
litellm.verbose_logger.disabled = True
# litellm.set_verbose = True

# Based on the param support matrix below:
# https://docs.litellm.ai/docs/completion/input
# Some providers do not support the `n` parameter
# and thus cannot generate multiple completions in one request
unsupported_multiple_gen_providers = (
"openrouter/",
"claude",
"replicate/",
"bedrock",
"petals",
"palm/",
"together_ai/",
"text-bison",
"text-bison@001",
"chat-bison",
"chat-bison@001",
"chat-bison-32k",
"code-bison",
"code-bison@001",
"code-gecko@001",
"code-gecko@latest",
"codechat-bison",
"codechat-bison@001",
"codechat-bison-32k",
)


class LiteLLMGenerator(Generator):
"""Generator wrapper using LiteLLM to allow access to different
providers using the OpenAI API format.
"""

supports_multiple_generations = True
generator_family_name = "LiteLLM"

temperature = 0.7
top_p = 1.0
frequency_penalty = 0.0
presence_penalty = 0.0
stop = ["#", ";"]

def __init__(self, name: str, generations: int = 10):
self.name = name
self.fullname = f"LiteLLM {self.name}"
self.generations = generations
self.api_base = None
self.api_key = None
self.provider = None
self.supports_multiple_generations = not any(
self.name.startswith(provider)
for provider in unsupported_multiple_gen_providers
)

super().__init__(name, generations=generations)

if "litellm.LiteLLMGenerator" in _config.plugins.generators:
for field in (
"api_key",
"provider",
"api_base",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
):
if field in _config.plugins.generators["litellm.LiteLLMGenerator"]:
setattr(
self,
field,
_config.plugins.generators["litellm.LiteLLMGenerator"][field],
)

if field == "provider" and self.api_key is None:
if self.provider == "openai":
self.api_key = getenv("OPENAI_API_KEY", None)
if self.api_key is None:
raise ValueError(
"Please supply an OpenAI API key in the OPENAI_API_KEY environment variable"
" or in the configuration file"
)
else:
if field in ("provider"): # required fields here
raise ValueError(
"litellm generator needs to have a provider value configured - see docs"
)

@backoff.on_exception(backoff.fibo, Exception, max_value=70)
def _call_model(
self, prompt: Union[str, List[dict]]
) -> Union[List[str], str, None]:
if isinstance(prompt, str):
prompt = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list):
prompt = prompt
else:
msg = (
f"Expected a list of dicts for LiteLLM model {self.name}, but got {type(prompt)} instead. "
f"Returning nothing!"
)
logging.error(msg)
print(msg)
return list()

response = litellm.completion(
model=self.name,
messages=prompt,
temperature=self.temperature,
top_p=self.top_p,
n=self.generations,
stop=self.stop,
max_tokens=self.max_tokens,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
api_base=self.api_base,
custom_llm_provider=self.provider,
api_key=self.api_key,
)

if self.supports_multiple_generations:
return [c.message.content for c in response.choices]
else:
return response.choices[0].message.content


default_class = "LiteLLMGenerator"
pyproject.toml (1 addition)

@@ -59,6 +59,7 @@ dependencies = [
"ecoji>=0.1.0",
"deepl==1.17.0",
"fschat>=0.2.36",
"litellm>=1.33.8",
"typing>=3.7,<3.8; python_version<'3.5'"
]

requirements.txt (2 additions, 1 deletion)

@@ -27,4 +27,5 @@ zalgolib>=0.2.2
ecoji>=0.1.0
deepl==1.17.0
fschat>=0.2.36
-typing>=3.7,<3.8; python_version<'3.5'
+litellm>=1.33.8
+typing>=3.7,<3.8; python_version<'3.5'
tests/generators/test_litellm.py (45 additions, new file)

@@ -0,0 +1,45 @@
import pytest

from os import getenv

from garak.generators.litellm import LiteLLMGenerator

DEFAULT_GENERATIONS_QTY = 10


@pytest.mark.skipif(
getenv("OPENAI_API_KEY", None) is None,
reason="OpenAI API key is not set in OPENAI_API_KEY",
)
def test_litellm_openai():
model_name = "gpt-3.5-turbo"
generator = LiteLLMGenerator(name=model_name)
assert generator.name == model_name
assert generator.generations == DEFAULT_GENERATIONS_QTY
assert isinstance(generator.max_tokens, int)

output = generator.generate("How do I write a sonnet?")
assert len(output) == DEFAULT_GENERATIONS_QTY

for item in output:
assert isinstance(item, str)
print("test passed!")


@pytest.mark.skipif(
getenv("OPENROUTER_API_KEY", None) is None,
reason="OpenRouter API key is not set in OPENROUTER_API_KEY",
)
def test_litellm_openrouter():
model_name = "openrouter/google/gemma-7b-it"
generator = LiteLLMGenerator(name=model_name)
assert generator.name == model_name
assert generator.generations == DEFAULT_GENERATIONS_QTY
assert isinstance(generator.max_tokens, int)

output = generator.generate("How do I write a sonnet?")
assert len(output) == DEFAULT_GENERATIONS_QTY

for item in output:
assert isinstance(item, str)
print("test passed!")