Feature/litellm generator (#572)
* build: add litellm to project requirements

* feat(litellm.py): implement litellm based garak generator

* fix: ensure cli validates that model name is provided for litellm generator

* fix(litellm.py): suppress logs generated by litellm

logs from litellm were flooding the output, making it hard to see the logs from garak

* fix(litellm.py): fix custom provider from config not being passed to litellm completion

* fix(litellm.py): raise error if openai api key not supplied when using openai as custom provider

* fix(litellm.py): fix bug where providers that do not support multiple generations caused extra generations to be ignored

* test(test_litellm.py): add unit tests for LiteLLMGenerator

* fix(litellm.py): fix syntax error caused by typo

* docs(litellm.py): add docstring

* project toml syntax fix

* add to docs; reject if no provider set

* factor literal out to variable

Co-authored-by: Jeffrey Martin <jmartin@Op3n4M3.dev>

* factor literal out to variable

Co-authored-by: Jeffrey Martin <jmartin@Op3n4M3.dev>

* sync signature w base class

Co-authored-by: Jeffrey Martin <jmartin@Op3n4M3.dev>

* raise exception if litellm request but no config is set

* it's fine to run litellm without a config (in some cases)

---------

Co-authored-by: Leon Derczynski <leonderczynski@gmail.com>
Co-authored-by: Jeffrey Martin <jmartin@Op3n4M3.dev>
3 people authored Apr 10, 2024
1 parent f43ea15 commit fde352c
Showing 6 changed files with 228 additions and 3 deletions.
3 changes: 2 additions & 1 deletion garak/cli.py
@@ -419,7 +419,8 @@ def main(arguments=[]) -> None:
# model is specified, we're doing something
elif _config.plugins.model_type:
if (
_config.plugins.model_type in ("openai", "replicate", "ggml", "huggingface")
_config.plugins.model_type
in ("openai", "replicate", "ggml", "huggingface", "litellm")
and not _config.plugins.model_name
):
message = f"⚠️ Model type '{_config.plugins.model_type}' also needs a model name\n You can set one with e.g. --model_name \"billwurtz/gpt-1.0\""
5 changes: 4 additions & 1 deletion garak/generators/__init__.py
@@ -13,7 +13,10 @@
def load_generator(
model_name: str, model_type: str, generations: int = 10
) -> Generator:
if model_type in ("openai", "replicate", "ggml", "huggingface") and not model_name:
if (
model_type in ("openai", "replicate", "ggml", "huggingface", "litellm")
and not model_name
):
message = f"⚠️ Model type '{model_type}' also needs a model name"
logger.error(message)
raise ValueError(message)
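
The same guard applies when loading a generator programmatically; a minimal sketch of the new behavior:

```python
from garak.generators import load_generator

try:
    load_generator(model_name="", model_type="litellm")
except ValueError as err:
    print(err)  # ⚠️ Model type 'litellm' also needs a model name
```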
174 changes: 174 additions & 0 deletions garak/generators/litellm.py
@@ -0,0 +1,174 @@
"""LiteLLM model support
Support for LiteLLM, which allows calling LLM APIs using the OpenAI format.
Depending on the model name provider, LiteLLM automatically
reads API keys from the respective environment variables.
(e.g. OPENAI_API_KEY for OpenAI models)
API key can also be directly set in the supplied generator json config.
This also enables support for any custom provider that follows the OAI format.
e.g Supply a JSON like this for Ollama's OAI api:
```json
{
"litellm.LiteLLMGenerator" : {
"api_base" : "http://localhost:11434/v1",
"provider" : "openai",
"api_key" : "test"
}
}
```
The above config connects LiteLLM to Ollama's OpenAI-compatible API.
Then, when invoking garak, pass the path to this config as the generator option file:
```
python -m garak --model_type litellm --model_name "phi" --generator_option_file ollama_base.json -p dan
```
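
For providers whose API keys are already exported in the environment
(e.g. OPENAI_API_KEY for OpenAI), garak can also run without a generator
config file in some cases; a sketch, assuming the key is set:
```
python -m garak --model_type litellm --model_name "gpt-3.5-turbo" -p dan
```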
"""

import logging

from os import getenv
from typing import List, Union

import backoff

import litellm

from garak import _config
from garak.generators.base import Generator

# Fix issue with Ollama, which does not support `presence_penalty`
litellm.drop_params = True
# Suppress log messages from LiteLLM
litellm.verbose_logger.disabled = True
# litellm.set_verbose = True

# Based on the param support matrix below:
# https://docs.litellm.ai/docs/completion/input
# Some providers do not support the `n` parameter
# and thus cannot generate multiple completions in one request
unsupported_multiple_gen_providers = (
"openrouter/",
"claude",
"replicate/",
"bedrock",
"petals",
"palm/",
"together_ai/",
"text-bison",
"text-bison@001",
"chat-bison",
"chat-bison@001",
"chat-bison-32k",
"code-bison",
"code-bison@001",
"code-gecko@001",
"code-gecko@latest",
"codechat-bison",
"codechat-bison@001",
"codechat-bison-32k",
)
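
# For these providers, garak's base Generator falls back to requesting
# completions one at a time instead of a single request with n > 1.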


class LiteLLMGenerator(Generator):
"""Generator wrapper using LiteLLM to allow access to different
providers using the OpenAI API format.
"""

supports_multiple_generations = True
generator_family_name = "LiteLLM"

temperature = 0.7
top_p = 1.0
frequency_penalty = 0.0
presence_penalty = 0.0
stop = ["#", ";"]

def __init__(self, name: str, generations: int = 10):
self.name = name
self.fullname = f"LiteLLM {self.name}"
self.generations = generations
self.api_base = None
self.api_key = None
self.provider = None
self.supports_multiple_generations = not any(
self.name.startswith(provider)
for provider in unsupported_multiple_gen_providers
)

super().__init__(name, generations=generations)

if "litellm.LiteLLMGenerator" in _config.plugins.generators:
for field in (
"api_key",
"provider",
"api_base",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
):
if field in _config.plugins.generators["litellm.LiteLLMGenerator"]:
setattr(
self,
field,
_config.plugins.generators["litellm.LiteLLMGenerator"][field],
)

if field == "provider" and self.api_key is None:
if self.provider == "openai":
self.api_key = getenv("OPENAI_API_KEY", None)
if self.api_key is None:
raise ValueError(
"Please supply an OpenAI API key in the OPENAI_API_KEY environment variable"
" or in the configuration file"
)
                else:
                    if field == "provider":  # required fields here
                        raise ValueError(
                            "litellm generator needs to have a provider value configured - see docs"
                        )

@backoff.on_exception(backoff.fibo, Exception, max_value=70)
def _call_model(
self, prompt: Union[str, List[dict]]
) -> Union[List[str], str, None]:
        if isinstance(prompt, str):
            prompt = [{"role": "user", "content": prompt}]
        elif not isinstance(prompt, list):
            msg = (
                f"Expected a list of dicts for LiteLLM model {self.name}, but got {type(prompt)} instead. "
                f"Returning nothing!"
            )
            logging.error(msg)
            print(msg)
            return list()

response = litellm.completion(
model=self.name,
messages=prompt,
temperature=self.temperature,
top_p=self.top_p,
n=self.generations,
stop=self.stop,
max_tokens=self.max_tokens,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
api_base=self.api_base,
custom_llm_provider=self.provider,
api_key=self.api_key,
)

if self.supports_multiple_generations:
return [c.message.content for c in response.choices]
else:
return response.choices[0].message.content


default_class = "LiteLLMGenerator"
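
For use outside the CLI, the generator can be driven directly; a minimal sketch (model name is illustrative and assumes OPENAI_API_KEY is set, as in the tests below):

```python
from garak.generators.litellm import LiteLLMGenerator

generator = LiteLLMGenerator(name="gpt-3.5-turbo", generations=3)
outputs = generator.generate("How do I write a sonnet?")  # list of 3 completions
for text in outputs:
    print(text)
```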
1 change: 1 addition & 0 deletions pyproject.toml
@@ -59,6 +59,7 @@ dependencies = [
"ecoji>=0.1.0",
"deepl==1.17.0",
"fschat>=0.2.36",
"litellm>=1.33.8",
"typing>=3.7,<3.8; python_version<'3.5'"
]

3 changes: 2 additions & 1 deletion requirements.txt
@@ -27,4 +27,5 @@ zalgolib>=0.2.2
ecoji>=0.1.0
deepl==1.17.0
fschat>=0.2.36
typing>=3.7,<3.8; python_version<'3.5'
litellm>=1.33.8
typing>=3.7,<3.8; python_version<'3.5'
45 changes: 45 additions & 0 deletions tests/generators/test_litellm.py
@@ -0,0 +1,45 @@
import pytest

from os import getenv

from garak.generators.litellm import LiteLLMGenerator

DEFAULT_GENERATIONS_QTY = 10


@pytest.mark.skipif(
getenv("OPENAI_API_KEY", None) is None,
reason="OpenAI API key is not set in OPENAI_API_KEY",
)
def test_litellm_openai():
model_name = "gpt-3.5-turbo"
generator = LiteLLMGenerator(name=model_name)
assert generator.name == model_name
assert generator.generations == DEFAULT_GENERATIONS_QTY
assert isinstance(generator.max_tokens, int)

output = generator.generate("How do I write a sonnet?")
assert len(output) == DEFAULT_GENERATIONS_QTY

for item in output:
assert isinstance(item, str)
print("test passed!")


@pytest.mark.skipif(
getenv("OPENROUTER_API_KEY", None) is None,
reason="OpenRouter API key is not set in OPENROUTER_API_KEY",
)
def test_litellm_openrouter():
model_name = "openrouter/google/gemma-7b-it"
generator = LiteLLMGenerator(name=model_name)
assert generator.name == model_name
assert generator.generations == DEFAULT_GENERATIONS_QTY
assert isinstance(generator.max_tokens, int)

output = generator.generate("How do I write a sonnet?")
assert len(output) == DEFAULT_GENERATIONS_QTY

for item in output:
assert isinstance(item, str)
print("test passed!")
