feature(dspy): added compatibility with ANY llamaindex LLM #1225

Open. Wants to merge 1 commit into base: main.
335 changes: 335 additions & 0 deletions dsp/modules/llamaindex.py
@@ -0,0 +1,335 @@
"""
This code implements a wrapper around the llama_index library to emulate a DSPy LM.
Collaborator:
Feel free to add these comments into the documentation for the LlamaIndexWrapper LM.


This allows any LLM supported by llama_index to be used within the DSPy framework, since DSPy itself has limited built-in LLM support.

This code is a slightly modified copy of dspy/dsp/modules/azure_openai.py.

It works by creating a dummy OpenAI client that wraps any llama_index LLM object and implements .complete and .chat.

Tested with Python 3.12 and:

dspy==0.1.4
dspy-ai==2.4.9
llama-index==0.10.35
llama-index-llms-openai==0.1.18

"""

import json
import logging
from typing import Any, Literal

from easydict import EasyDict
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms import LLM


def LlamaIndexOpenAIClientWrapper(llm: LLM):
Collaborator:
To clarify, does this support compatibility with all LlamaIndex LLM integrations? It seems to me like the definition here is only for OpenAI clients, which would leave this change redundant to just using DSPy.OpenAI. If it is not just isolated to OpenAI, let's refactor this function to LlamaIndexLLMWrapper.

Author:
This is supposed to support all LlamaIndex LLMs. However, the way I use this is by wrapping the LlamaIndex LLM so that it looks like an OpenAI client object, and then passing that to a DSPy Azure OpenAI LM. LlamaIndex's llm.chat and llm.complete are used for the actual calls.
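
For example, a minimal sketch of the intended usage with a non-OpenAI backend (illustrative; assumes the optional llama-index-llms-anthropic package and a placeholder model name):

    import dspy
    from llama_index.llms.anthropic import Anthropic

    llm = Anthropic(model='claude-3-haiku-20240307')
    dspy_llm = DspyLlamaIndexWrapper(llm)  # llm.chat / llm.complete handle the actual calls
    dspy.settings.configure(lm=dspy_llm)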

    def chat(messages: list[ChatMessage], **kwargs) -> Any:
        return llm.chat([ChatMessage(**message) for message in messages], **kwargs)

    def complete(prompt: str, **kwargs) -> Any:
        return llm.complete(prompt, **kwargs)

    client = EasyDict(
        {
            'chat': EasyDict({'completions': EasyDict({'create': chat})}),
            'completion': EasyDict({'create': complete}),
            'ChatCompletion': EasyDict({'create': chat}),
            'Completion': EasyDict({'create': complete}),
        }
    )
    return client
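
# Note (illustrative, not part of the original diff): EasyDict exposes dict keys as
# attributes, so the object returned above mimics the surface of an OpenAI client
# closely enough for the request helpers below, e.g.:
#
#     client = LlamaIndexOpenAIClientWrapper(llm)
#     client.chat.completions.create(messages=[{'role': 'user', 'content': 'hi'}])
#     # -> llm.chat([ChatMessage(role='user', content='hi')])
#     client.completion.create(prompt='hi')  # -> llm.complete('hi')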


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[logging.FileHandler('azure_openai_usage.log')],
)

import functools

import backoff
import dsp
import openai
from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
from dsp.modules.lm import LM

try:
    OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
except Exception:
    OPENAI_LEGACY = True

try:
    import openai.error
    from openai.openai_object import OpenAIObject

    ERRORS = (
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
        openai.error.APIError,
    )
except Exception:
    ERRORS = (openai.RateLimitError, openai.APIError)
    OpenAIObject = dict


def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        'Backing off {wait:0.1f} seconds after {tries} tries '
        'calling function {target} with kwargs '
        '{kwargs}'.format(**details),
    )


class DspyLlamaIndexWrapper(LM):
Collaborator:
Refactor the name to LlamaIndexWrapper, and please add it to the imports in lm.py as is done with the other modules.
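
For example, a hypothetical one-line sketch of such an export, mirroring how the other LM modules are exposed (exact file per the reviewer's note):

    from .llamaindex import LlamaIndexWrapper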

    """Wrapper around Azure's API for OpenAI.
Collaborator:
Is the Azure OpenAI documentation still relevant here? Please refactor if not.


    Args:
        api_base (str): Azure URL endpoint for model calling, often called 'azure_endpoint'.
        api_version (str): Version identifier for API.
        model (str, optional): OpenAI or Azure supported LLM model to use. Defaults to "text-davinci-002".
        api_key (Optional[str], optional): API provider authentication token. Defaults to None.
        model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "chat".
        **kwargs: Additional arguments to pass to the API provider.
    """

    def __init__(
        self,
        llm: LLM,
        model_type: Literal['chat', 'text'] = 'chat',
        **kwargs,
    ):
        super().__init__(llm.model)
        self.provider = 'openai'

        self.llm = llm
        self.client = LlamaIndexOpenAIClientWrapper(llm)
        model = llm.model
        self.model_type = model_type

        # if not OPENAI_LEGACY and "model" not in kwargs:
        #     if "deployment_id" in kwargs:
        #         kwargs["model"] = kwargs["deployment_id"]
        #         del kwargs["deployment_id"]

        #     if "api_version" in kwargs:
        #         del kwargs["api_version"]

        if 'model' not in kwargs:
            kwargs['model'] = model

        self.kwargs = {
            'temperature': 0.0,
            'max_tokens': 150,
            'top_p': 1,
            'frequency_penalty': 0,
            'presence_penalty': 0,
            'n': 1,
            **kwargs,
        }  # TODO: add kwargs above for </s>

        self.history: list[dict[str, Any]] = []
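
    # Illustrative (not part of the original diff): extra kwargs passed to the constructor
    # override the sampling defaults assembled above, e.g.
    #
    #     dspy_llm = DspyLlamaIndexWrapper(llm, model_type='chat', max_tokens=512, temperature=0.7)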

    def _openai_client(self):
Collaborator:
Refactor this if the scope of this LM is not limited to just OpenAI.

        # if OPENAI_LEGACY:
        #     return openai

        return self.client

    def log_usage(self, response):
        """Log the total tokens from the Azure OpenAI API response."""
        usage_data = response.get('usage')
        if usage_data:
            total_tokens = usage_data.get('total_tokens')
            logging.info(f'{total_tokens}')

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}
        if self.model_type == 'chat':
            # caching mechanism requires hashable kwargs
            kwargs['messages'] = [{'role': 'user', 'content': prompt}]
            kwargs = {'stringify_request': json.dumps(kwargs)}
            # response = chat_request(self.client, **kwargs)
            # if OPENAI_LEGACY:
            #     return _cached_gpt3_turbo_request_v2_wrapped(**kwargs)
            # else:
            response = v1_chat_request(self.client, **kwargs)
        else:
            kwargs['prompt'] = prompt
            response = self.completions_request(**kwargs)

        history = {
            'prompt': prompt,
            'response': response,
            'kwargs': kwargs,
            'raw_kwargs': raw_kwargs,
        }
        self.history.append(history)

        return response
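
    # A note on the caching trick above (illustrative, not part of the original diff): chat
    # kwargs are collapsed into one JSON string so the lru_cache / disk-cache decorators in
    # v1_chat_request can hash them, e.g.
    #
    #     {'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': prompt}], ...}
    #     -> {'stringify_request': '{"model": "gpt-3.5-turbo", "messages": [...], ...}'}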

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_time=1000,
        on_backoff=backoff_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of GPT-3 completions whilst handling rate limiting and caching."""
        if 'model_type' in kwargs:
            del kwargs['model_type']

        return self.basic_request(prompt, **kwargs)

    def _get_choice_text(self, choice: dict[str, Any]) -> str:
        if self.model_type == 'chat':
            return choice['message']['content']
        return choice['text']

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Retrieves completions from the wrapped model.

        Args:
            prompt (str): prompt to send to the model
            only_completed (bool, optional): return only completed responses and ignore completions cut off due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[dict[str, Any]]: list of completion choices
        """

        assert only_completed, 'for now'
        assert return_sorted is False, 'for now'

        response = self.request(prompt, **kwargs)

        try:
            if dsp.settings.log_openai_usage:
                self.log_usage(response)
        except Exception:
            pass

        choices = response['choices']

        completed_choices = [c for c in choices if c['finish_reason'] != 'length']

        if only_completed and len(completed_choices):
            choices = completed_choices

        completions = [self._get_choice_text(c) for c in choices]
        if return_sorted and kwargs.get('n', 1) > 1:
            scored_completions = []

            for c in choices:
                tokens, logprobs = (
                    c['logprobs']['tokens'],
                    c['logprobs']['token_logprobs'],
                )

                if '<|endoftext|>' in tokens:
                    index = tokens.index('<|endoftext|>') + 1
                    tokens, logprobs = tokens[:index], logprobs[:index]

                avglog = sum(logprobs) / len(logprobs)
                scored_completions.append((avglog, self._get_choice_text(c)))

            scored_completions = sorted(scored_completions, reverse=True)
            completions = [c for _, c in scored_completions]

        return completions
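
    # Example of calling the wrapper directly (illustrative, not part of the original diff):
    #
    #     dspy_llm = DspyLlamaIndexWrapper(llm)
    #     dspy_llm('What is 2 + 2?')  # -> a list of completion strings, e.g. ['4']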

    def completions_request(self, **kwargs):
        # if OPENAI_LEGACY:
        #     return cached_gpt3_request_v2_wrapped(**kwargs)
        return v1_completions_request(self.client, **kwargs)


def v1_chat_request(client, **kwargs):
    @functools.lru_cache(maxsize=None if cache_turn_on else 0)
    @NotebookCacheMemory.cache
    def v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs):
        @CacheMemory.cache
        def v1_cached_gpt3_turbo_request_v2(**kwargs):
            if 'stringify_request' in kwargs:
                kwargs = json.loads(kwargs['stringify_request'])
            return client.chat.completions.create(**kwargs)

        return v1_cached_gpt3_turbo_request_v2(**kwargs)

    response = v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs)

    try:
        response = response.model_dump()
    except Exception:
        response = response.raw
        response['choices'] = [json.loads(x.json()) for x in response['choices']]
        response['usage'] = json.loads(response['usage'].json())
    return response


def v1_completions_request(client, **kwargs):
    @functools.lru_cache(maxsize=None if cache_turn_on else 0)
    @NotebookCacheMemory.cache
    def v1_cached_gpt3_request_v2_wrapped(**kwargs):
        @CacheMemory.cache
        def v1_cached_gpt3_request_v2(**kwargs):
            return client.completions.create(**kwargs)

        return v1_cached_gpt3_request_v2(**kwargs)

    return v1_cached_gpt3_request_v2_wrapped(**kwargs).model_dump()


## ======== test =========

if __name__ == '__main__':
Collaborator:
Remove this and add it to the documentation for the LM.

    print('Testing DspyLlamaIndexWrapper')
    import os

    import dspy
    from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
    from llama_index.llms.openai import OpenAI

    llm = OpenAI(api_key=os.environ['OPENAI_API_KEY'], model='gpt-3.5-turbo')
    dspy_llm = DspyLlamaIndexWrapper(llm)

    # Load math questions from the GSM8K dataset.
    gsm8k = GSM8K()
    gsm8k_trainset, gsm8k_devset = gsm8k.train[:10], gsm8k.dev[:10]

    class CoT(dspy.Module):
        def __init__(self):
            super().__init__()
            self.prog = dspy.ChainOfThought('question -> answer')

        def forward(self, question):
            response = self.prog(question=question)
            return response

    ##

    dspy.settings.configure(lm=dspy_llm)

    from dspy.teleprompt import BootstrapFewShot

    # Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 4-shot examples of our CoT program.
    config = dict(max_bootstrapped_demos=4, max_labeled_demos=4)

    # Optimize! Use the `gsm8k_metric` here. In general, the metric is going to tell the optimizer how well it's doing.
    teleprompter = BootstrapFewShot(metric=gsm8k_metric, **config)
    optimized_cot = teleprompter.compile(CoT(), trainset=gsm8k_trainset)
    print(f'{optimized_cot=}')
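
    # A possible follow-up (illustrative, not part of the original diff): score the optimized
    # program on the small dev split with DSPy's Evaluate utility.
    #
    #     from dspy.evaluate import Evaluate
    #     evaluate = Evaluate(devset=gsm8k_devset, metric=gsm8k_metric, display_progress=True)
    #     evaluate(optimized_cot)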