feat(dspy): Experiment with adding image data with GPT-4o and Gemini #1099

Open · wants to merge 3 commits into base: main
21 changes: 18 additions & 3 deletions dsp/modules/azure_openai.py
@@ -7,7 +7,7 @@
 import openai
 
 from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
-from dsp.modules.lm import LM
+from dsp.modules.lm import IMG_DATA_KEY, LM
 
 try:
     OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
@@ -134,8 +134,23 @@ def basic_request(self, prompt: str, **kwargs):
 
         kwargs = {**self.kwargs, **kwargs}
         if self.model_type == "chat":
-            # caching mechanism requires hashable kwargs
-            messages = [{"role": "user", "content": prompt}]
+            content = prompt
+            # Add image data if provided in kwargs
+            if IMG_DATA_KEY in kwargs:
+                img_data = kwargs.pop(IMG_DATA_KEY)
+                content = [
+                    {
+                        "type": "text",
+                        "text": prompt,
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{img_data}",
+                        },
+                    },
+                ]
+            messages = [{"role": "user", "content": content}]
             if self.system_prompt:
                 messages.insert(0, {"role": "system", "content": self.system_prompt})
 
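For review context, here is a minimal sketch of the request payload this branch builds when image data is attached. The file path and prompt are illustrative; the content shape mirrors the diff above and OpenAI's multimodal chat format:

    import base64

    # Assumption: a local JPEG, encoded the way the data URL above expects.
    with open("photo.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # The messages structure basic_request produces when img_data is supplied:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
                },
            ],
        },
    ]

Note the hard-coded image/jpeg MIME type: a PNG passed through this path would be mislabeled in the data URL.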
16 changes: 14 additions & 2 deletions dsp/modules/google.py
@@ -1,10 +1,11 @@
+import base64
 import os
 from collections.abc import Iterable
 from typing import Any, Optional
 
 import backoff
 
-from dsp.modules.lm import LM
+from dsp.modules.lm import IMG_DATA_KEY, LM
 
 try:
     import google.generativeai as genai
@@ -118,10 +119,21 @@ def basic_request(self, prompt: str, **kwargs):
         if n is not None and n > 1 and kwargs['temperature'] == 0.0:
             kwargs['temperature'] = 0.7
 
-        response = self.llm.generate_content(prompt, generation_config=kwargs)
+        content = [prompt]
+        base64_data = ""
+        if IMG_DATA_KEY in kwargs:
+            base64_data = kwargs.pop(IMG_DATA_KEY)
+            # convert base64 string to bytes
+            content.append({
+                "mime_type": "image/jpeg",
+                "data": base64.b64decode(base64_data),
+            })
+
+        response = self.llm.generate_content(content, generation_config=kwargs)
 
         history = {
             "prompt": prompt,
+            "image": base64_data,
             "response": [response],
             "kwargs": kwargs,
             "raw_kwargs": raw_kwargs,
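A hedged sketch of what the Gemini call receives once an image is attached; the model name and prompt are illustrative, and the parts list matches what the diff assembles:

    import base64

    import google.generativeai as genai

    llm = genai.GenerativeModel("gemini-pro-vision")  # assumption: any vision-capable Gemini model

    with open("photo.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # The diff builds this parts list: the text prompt followed by an inline image blob.
    response = llm.generate_content([
        "What is shown in this image?",
        {"mime_type": "image/jpeg", "data": base64.b64decode(img_b64)},
    ])
    print(response.text)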
22 changes: 19 additions & 3 deletions dsp/modules/gpt3.py
@@ -7,7 +7,7 @@
 import openai
 
 from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
-from dsp.modules.lm import LM
+from dsp.modules.lm import IMG_DATA_KEY, LM
 
 try:
     OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
@@ -108,10 +108,26 @@ def basic_request(self, prompt: str, **kwargs):
 
         kwargs = {**self.kwargs, **kwargs}
         if self.model_type == "chat":
-            # caching mechanism requires hashable kwargs
-            messages = [{"role": "user", "content": prompt}]
+            content = prompt
+            # Add image data if provided in kwargs
+            if IMG_DATA_KEY in kwargs:
+                img_data = kwargs.pop(IMG_DATA_KEY)
+                content = [
+                    {
+                        "type": "text",
+                        "text": prompt,
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{img_data}",
+                        },
+                    },
+                ]
+            messages = [{"role": "user", "content": content}]
             if self.system_prompt:
                 messages.insert(0, {"role": "system", "content": self.system_prompt})
+            # caching mechanism requires hashable kwargs
             kwargs["messages"] = messages
             kwargs = {"stringify_request": json.dumps(kwargs)}
             response = chat_request(**kwargs)
3 changes: 3 additions & 0 deletions dsp/modules/lm.py
@@ -1,5 +1,8 @@
 from abc import ABC, abstractmethod
 
+# The key used to pass image data to the language model.
+# TODO: maybe change the basic_request signature to do this?
+IMG_DATA_KEY = "img_data"
 
 class LM(ABC):
     """Abstract class for language models."""
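With the key defined on the base module, any of the patched backends can accept an image alongside the prompt. A minimal end-to-end sketch, assuming a vision-capable model and the dspy.OpenAI wrapper (the model name and file are illustrative):

    import base64

    import dspy

    lm = dspy.OpenAI(model="gpt-4o", model_type="chat")

    with open("photo.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # "img_data" matches IMG_DATA_KEY; basic_request pops it and builds the
    # multimodal message instead of a plain text prompt.
    completions = lm("Describe this image.", img_data=img_b64)
    print(completions[0])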
17 changes: 13 additions & 4 deletions dsp/primitives/predict.py
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Optional
 
 import dsp
+from dsp.modules.lm import IMG_DATA_KEY
 from dsp.primitives.demonstrate import Example
 from dsp.templates.template_v3 import Template
 from dsp.utils import normalize_text, zipstar
@@ -69,16 +70,16 @@ def extend_generation(completion: Example, field_names: list[str], stage:str, ma
         completion.pop(field_name, None)
 
     # Recurse with greedy decoding and a shorter length.
-    max_tokens = (kwargs.get("max_tokens") or 
+    max_tokens = (kwargs.get("max_tokens") or
                   kwargs.get("max_output_tokens") or
-                  dsp.settings.lm.kwargs.get("max_tokens") or 
+                  dsp.settings.lm.kwargs.get("max_tokens") or
                   dsp.settings.lm.kwargs.get('max_output_tokens'))
 
 
     if max_tokens is None:
         raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.")
     max_tokens = min(max(75, max_tokens // 2), max_tokens)
-    keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys()) 
+    keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys())
     max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens"
     new_kwargs = {
         **kwargs,
@@ -94,7 +95,7 @@ def extend_generation(completion: Example, field_names: list[str], stage:str, ma
         original_example=original_example,
     )
     return finished_completion.data[0]
-    
+
 
 def do_generate(
     example: Example, stage: str, max_depth: int = 2, original_example=None,
@@ -109,6 +110,14 @@ def do_generate(
 
     # Generate and extract the fields.
     prompt = template(example)
+
+    # For image data, we need to embed that in the kwargs to be passed
+    # to the generator and the LM.
+    for field in template.fields:
+        if field.is_image:
+            kwargs[IMG_DATA_KEY] = example[field.input_variable]
+            break
+
     completions: list[dict[str, Any]] = generator(prompt, **kwargs)
     completions: list[Example] = [template.extract(example, p) for p in completions]
 
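Worth flagging for review: the loop forwards only the first field marked is_image and then breaks. A self-contained toy illustration of that selection rule (stand-in Field tuple, not the real template types):

    from collections import namedtuple

    Field = namedtuple("Field", "input_variable is_image")

    fields = [
        Field("question", False),
        Field("photo", True),
        Field("diagram", True),  # a second image field is silently ignored
    ]
    example = {"question": "Compare the two images.", "photo": "<b64 A>", "diagram": "<b64 B>"}

    kwargs = {}
    for field in fields:
        if field.is_image:
            # Only the first image field wins.
            kwargs["img_data"] = example[field.input_variable]
            break

    print(kwargs)  # {'img_data': '<b64 A>'}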
9 changes: 8 additions & 1 deletion dsp/templates/template_v2.py
@@ -7,7 +7,7 @@
 
 from .utils import format_answers, passages2text
 
-Field = namedtuple("Field", "name separator input_variable output_variable description")
+Field = namedtuple("Field", "name separator input_variable output_variable description is_image")
 
 # TODO: de-duplicate with dsp/templates/template.py
 
@@ -63,6 +63,8 @@ def __init__(
                     input_variable=input_variable,
                     output_variable=output_variable,
                     description=description,
+                    # See template_v3 for handling image fields
+                    is_image=False,
                 ),
             )
 
@@ -94,6 +96,11 @@ def query(self, example: Example, is_demo: bool = False) -> str:
 
         for field in self.fields:
             if field.input_variable in example and example[field.input_variable] is not None:
+                if field.is_image:
+                    # For image fields, we do not include this in the prompt
+                    # since we want to pass this straight to the LM
+                    continue
+
                 if field.input_variable in self.format_handlers:
                     format_handler = self.format_handlers[field.input_variable]
                 else:
8 changes: 4 additions & 4 deletions dsp/templates/template_v3.py
@@ -44,13 +44,13 @@ def __init__(self, instructions: str, **kwargs):
                 input_variable=key,
                 output_variable=key,
                 separator=separator,
+                is_image=bool(value.is_image),
             )
             self.fields.append(field)
 
             if value.format:
                 self.format_handlers[key] = value.format
-    
-    
+
     # equality
     def __eq__(self, other):
         if set(self.kwargs.keys()) != set(other.kwargs.keys()):
@@ -62,7 +62,7 @@ def __eq__(self, other):
             if not v1 == v2:
                 print(k, v1, v2)
 
-        
+
         # print("here?", self.instructions == other.instructions, self.kwargs == other.kwargs)
         return self.instructions == other.instructions and self.kwargs == other.kwargs
 
@@ -71,4 +71,4 @@ def __str__(self) -> str:
         field_names = [field.name for field in self.fields]
 
         return f"Template({self.instructions}, {field_names})"
-    
+
14 changes: 8 additions & 6 deletions dspy/signatures/field.py
@@ -3,7 +3,7 @@
 # The following arguments can be used in DSPy InputField and OutputField in addition
 # to the standard pydantic.Field arguments. We just hope pydanitc doesn't add these,
 # as it would give a name clash.
-DSPY_FIELD_ARG_NAMES = ["desc", "prefix", "format", "parser", "__dspy_field_type"]
+DSPY_FIELD_ARG_NAMES = ["desc", "prefix", "format", "is_image", "parser", "__dspy_field_type"]
 
 
 def move_kwargs(**kwargs):
@@ -39,16 +39,18 @@ def new_to_old_field(field):
         prefix=field.json_schema_extra["prefix"],
         desc=field.json_schema_extra["desc"],
         format=field.json_schema_extra.get("format"),
+        is_image=field.json_schema_extra.get("is_image"),
     )
 
 
 class OldField:
     """A more ergonomic datatype that infers prefix and desc if omitted."""
 
-    def __init__(self, *, prefix=None, desc=None, input, format=None):
+    def __init__(self, *, prefix=None, desc=None, input, format=None, is_image=False):
         self.prefix = prefix  # This can be None initially and set later
         self.desc = desc
         self.format = format
+        self.is_image = is_image
 
     def finalize(self, key, inferred_prefix):
         """Set the prefix if it's not provided explicitly."""
@@ -66,10 +68,10 @@ def __eq__(self, __value: object) -> bool:
 
 
 class OldInputField(OldField):
-    def __init__(self, *, prefix=None, desc=None, format=None):
-        super().__init__(prefix=prefix, desc=desc, input=True, format=format)
+    def __init__(self, *, prefix=None, desc=None, format=None, is_image=False):
+        super().__init__(prefix=prefix, desc=desc, input=True, format=format, is_image=is_image)
 
 
 class OldOutputField(OldField):
-    def __init__(self, *, prefix=None, desc=None, format=None):
-        super().__init__(prefix=prefix, desc=desc, input=False, format=format)
+    def __init__(self, *, prefix=None, desc=None, format=None, is_image=False):
+        super().__init__(prefix=prefix, desc=desc, input=False, format=format, is_image=is_image)
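Since is_image is now accepted by InputField and carried through move_kwargs, a signature can flag an image input directly. A hedged sketch of how that might look; the signature and field names are illustrative, not from this diff:

    import dspy

    class DescribeImage(dspy.Signature):
        """Answer a question about the supplied image."""

        question = dspy.InputField()
        image = dspy.InputField(is_image=True, desc="base64-encoded JPEG")
        answer = dspy.OutputField()

    # The image field is excluded from the rendered prompt (template_v2.query)
    # and forwarded to the LM as img_data (dsp/primitives/predict.py).
    predict = dspy.Predict(DescribeImage)
    # result = predict(question="What breed of dog is this?", image=img_b64)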
27 changes: 21 additions & 6 deletions dspy/teleprompt/mipro_optimizer.py
@@ -3,12 +3,13 @@
 import sys
 import textwrap
 from collections import defaultdict
-from typing import Any
+from typing import Any, Callable, Optional
 
 import optuna
 
 import dsp
 import dspy
+from dsp.templates.template_v2 import Field
 from dspy.evaluate.evaluate import Evaluate
 from dspy.signatures import Signature
 from dspy.signatures.signature import signature_to_template
@@ -38,7 +39,7 @@
 * init_temperature: The temperature used to generate new prompts. Higher roughly equals more creative. Default=1.0.
 * verbose: Tells the method whether or not to print intermediate steps.
 * track_stats: Tells the method whether or not to track statistics about the optimization process.
-                If True, the method will track a dictionary with a key corresponding to the trial number, 
+                If True, the method will track a dictionary with a key corresponding to the trial number,
                 and a value containing a dict with the following keys:
                     * program: the program being evaluated at a given trial
                     * score: the last average evaluated score for the program
@@ -131,6 +132,8 @@ class DatasetDescriptorWithPriorObservations(dspy.Signature):
         desc="Somethings that holds true for most or all of the data you observed or COMPLETE if you have nothing to add",
     )
 
+ExampleStringifyFn = Callable[[dsp.Example, list[str]], str]
+
 
 class MIPRO(Teleprompter):
     def __init__(
@@ -144,6 +147,7 @@ def __init__(
         verbose=False,
         track_stats=True,
         view_data_batch_size=10,
+        example_stringify_fn: Optional[ExampleStringifyFn] = None,
     ):
         self.num_candidates = num_candidates
         self.metric = metric
@@ -154,6 +158,7 @@ def __init__(
         self.track_stats = track_stats
         self.teacher_settings = teacher_settings
         self.view_data_batch_size = view_data_batch_size
+        self.example_stringify_fn = example_stringify_fn
 
     def _print_full_program(self, program):
         for i, predictor in enumerate(program.predictors()):
@@ -174,7 +179,12 @@ def _print_model_history(self, model, n=1):
 
     def _observe_data(self, trainset, max_iterations=10):
         upper_lim = min(len(trainset), self.view_data_batch_size)
-        observation = dspy.Predict(DatasetDescriptor, n=1, temperature=1.0)(examples=(trainset[0:upper_lim].__repr__()))
+        example_stringify_fn = self.example_stringify_fn
+        if example_stringify_fn is None:
+            example_stringify_fn = lambda e, _fields: repr(e)
+        observation = dspy.Predict(DatasetDescriptor, n=1, temperature=1.0)(examples=(
+            "\n".join([example_stringify_fn(e, list(e.keys())) for e in trainset[0:upper_lim]])
+        ))
         observations = observation["observations"]
 
         skips = 0
@@ -183,7 +193,7 @@ def _observe_data(self, trainset, max_iterations=10):
             upper_lim = min(len(trainset), b + self.view_data_batch_size)
             output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)(
                 prior_observations=observations,
-                examples=(trainset[b:upper_lim].__repr__()),
+                examples="\n".join([example_stringify_fn(e, list(e.keys())) for e in trainset[b:upper_lim]]),
             )
             iterations += 1
             if len(output["observations"]) >= 8 and output["observations"][:8].upper() == "COMPLETE":
@@ -199,8 +209,13 @@ def _observe_data(self, trainset, max_iterations=10):
 
         return summary.summary
 
-    def _create_example_string(self, fields, example):
+    def _create_example_string(self, fields: list[Field], example: dsp.Example) -> str:
         # Building the output string
+
+        if self.example_stringify_fn is not None:
+            field_names = [field.input_variable for field in fields]
+            return self.example_stringify_fn(example, field_names)
+
         output = []
         for field in fields:
             name = field.name
Expand Down Expand Up @@ -383,7 +398,7 @@ def compile(

{YELLOW}{BOLD}Estimated Cost Calculation:{ENDC}

{YELLOW}Total Cost = (Number of calls to task model * (Avg Input Token Length per Call * Task Model Price per Input Token + Avg Output Token Length per Call * Task Model Price per Output Token)
{YELLOW}Total Cost = (Number of calls to task model * (Avg Input Token Length per Call * Task Model Price per Input Token + Avg Output Token Length per Call * Task Model Price per Output Token)
+ (Number of calls to prompt model * (Avg Input Token Length per Call * Task Prompt Price per Input Token + Avg Output Token Length per Call * Prompt Model Price per Output Token).{ENDC}

For a preliminary estimate of potential costs, we recommend you perform your own calculations based on the task
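Finally, a hedged sketch of supplying example_stringify_fn so MIPRO's data-observation prompts don't embed raw base64 payloads; the truncation rule and metric are illustrative:

    import dsp
    import dspy
    from dspy.teleprompt.mipro_optimizer import MIPRO

    def stringify_example(example: dsp.Example, field_names: list[str]) -> str:
        # Render each requested field, truncating bulky values such as base64 images.
        parts = []
        for name in field_names:
            text = str(example.get(name, ""))
            parts.append(f"{name}: {text[:80] + '...' if len(text) > 80 else text}")
        return " | ".join(parts)

    def my_metric(gold, pred, trace=None):
        # Placeholder metric for illustration only.
        return gold.answer == pred.answer

    optimizer = MIPRO(metric=my_metric, example_stringify_fn=stringify_example)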