fix(dspy): Experiment with adding image data with GPT-4o
dat-boris committed Jun 8, 2024
1 parent 16432dd commit 2ec86d8
Showing 8 changed files with 90 additions and 26 deletions.
16 changes: 14 additions & 2 deletions dsp/modules/google.py
```diff
@@ -1,10 +1,11 @@
+import base64
 import os
 from collections.abc import Iterable
 from typing import Any, Optional
 
 import backoff
 
-from dsp.modules.lm import LM
+from dsp.modules.lm import IMG_DATA_KEY, LM
 
 try:
     import google.generativeai as genai
@@ -118,10 +119,21 @@ def basic_request(self, prompt: str, **kwargs):
         if n is not None and n > 1 and kwargs['temperature'] == 0.0:
             kwargs['temperature'] = 0.7
 
-        response = self.llm.generate_content(prompt, generation_config=kwargs)
+        content = [prompt]
+        base64_data = ""
+        if IMG_DATA_KEY in kwargs:
+            base64_data = kwargs.pop(IMG_DATA_KEY)
+            # convert base64 string to bytes
+            content.append({
+                "mime_type": "image/jpeg",
+                "data": base64.b64decode(base64_data),
+            })
+
+        response = self.llm.generate_content(content, generation_config=kwargs)
 
         history = {
             "prompt": prompt,
             "image": base64_data,
             "response": [response],
             "kwargs": kwargs,
             "raw_kwargs": raw_kwargs,
```
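With this change, a caller can attach a base64-encoded JPEG to a Gemini request through the new kwarg. A minimal sketch of the calling side, assuming the `Google` client class defined in this module (the model name and file path are illustrative):

```python
import base64

import dsp

# Illustrative setup; any vision-capable Gemini model would do.
lm = dsp.Google(model="gemini-pro-vision")

with open("photo.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

# The kwarg name matches IMG_DATA_KEY ("img_data") in dsp/modules/lm.py.
response = lm.basic_request("Describe this image.", img_data=img_b64)
```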
22 changes: 19 additions & 3 deletions dsp/modules/gpt3.py
```diff
@@ -7,7 +7,7 @@
 import openai
 
 from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
-from dsp.modules.lm import LM
+from dsp.modules.lm import IMG_DATA_KEY, LM
 
 try:
     OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
@@ -108,10 +108,26 @@ def basic_request(self, prompt: str, **kwargs):
 
         kwargs = {**self.kwargs, **kwargs}
         if self.model_type == "chat":
-            # caching mechanism requires hashable kwargs
-            messages = [{"role": "user", "content": prompt}]
+            content = prompt
+            # Add image data if provided in kwargs
+            if IMG_DATA_KEY in kwargs:
+                img_data = kwargs.pop(IMG_DATA_KEY)
+                content = [
+                    {
+                        "type": "text",
+                        "text": prompt,
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{img_data}",
+                        },
+                    },
+                ]
+            messages = [{"role": "user", "content": content}]
             if self.system_prompt:
                 messages.insert(0, {"role": "system", "content": self.system_prompt})
+            # caching mechanism requires hashable kwargs
             kwargs["messages"] = messages
             kwargs = {"stringify_request": json.dumps(kwargs)}
             response = chat_request(**kwargs)
```
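The OpenAI path uses the same kwarg: when `img_data` is present, the user turn becomes a two-part text-plus-`image_url` message carrying a base64 data URL, which is the shape GPT-4o's vision input expects. A hedged usage sketch, reusing `img_b64` from the example above (the model name is illustrative):

```python
import dspy

# Assumes OPENAI_API_KEY is set in the environment.
gpt4o = dspy.OpenAI(model="gpt-4o", model_type="chat", max_tokens=300)

raw_response = gpt4o.basic_request("What is in this picture?", img_data=img_b64)
```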
3 changes: 3 additions & 0 deletions dsp/modules/lm.py
```diff
@@ -1,5 +1,8 @@
 from abc import ABC, abstractmethod
 
+# The key used to pass image data to the language model.
+# TODO: maybe change the basic_request signature to do this?
+IMG_DATA_KEY = "img_data"
 
 class LM(ABC):
     """Abstract class for language models."""
```
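Backends that support images follow the same convention: pop this key out of kwargs before the remaining kwargs are treated as generation parameters. A minimal sketch of that contract (the backend class is hypothetical):

```python
from dsp.modules.lm import IMG_DATA_KEY, LM


class MyVisionLM(LM):
    """Hypothetical backend illustrating the IMG_DATA_KEY convention."""

    def basic_request(self, prompt: str, **kwargs):
        # Pull the image out first; everything left in kwargs is treated
        # as ordinary generation parameters by the provider client.
        img_data = kwargs.pop(IMG_DATA_KEY, None)  # base64 string or None
        request = {"prompt": prompt, "image": img_data, "params": kwargs}
        return request  # a real backend would call its provider here
```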
17 changes: 13 additions & 4 deletions dsp/primitives/predict.py
```diff
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Optional
 
 import dsp
+from dsp.modules.lm import IMG_DATA_KEY
 from dsp.primitives.demonstrate import Example
 from dsp.templates.template_v3 import Template
 from dsp.utils import normalize_text, zipstar
@@ -69,16 +70,16 @@ def extend_generation(completion: Example, field_names: list[str], stage:str, ma
         completion.pop(field_name, None)
 
     # Recurse with greedy decoding and a shorter length.
-    max_tokens = (kwargs.get("max_tokens") or 
+    max_tokens = (kwargs.get("max_tokens") or
                   kwargs.get("max_output_tokens") or
-                  dsp.settings.lm.kwargs.get("max_tokens") or 
+                  dsp.settings.lm.kwargs.get("max_tokens") or
                   dsp.settings.lm.kwargs.get('max_output_tokens'))
-    
+
     if max_tokens is None:
        raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.")
     max_tokens = min(max(75, max_tokens // 2), max_tokens)
-    keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys()) 
+    keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys())
     max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens"
     new_kwargs = {
         **kwargs,
@@ -94,7 +95,7 @@ def extend_generation(completion: Example, field_names: list[str], stage:str, ma
         original_example=original_example,
     )
     return finished_completion.data[0]
-    
+
 
 def do_generate(
     example: Example, stage: str, max_depth: int = 2, original_example=None,
@@ -109,6 +110,14 @@ def do_generate(
 
     # Generate and extract the fields.
     prompt = template(example)
+
+    # For image data, we need to embed that in the kwargs to be passed
+    # to the generator and the LM.
+    for field in template.fields:
+        if field.is_image:
+            kwargs[IMG_DATA_KEY] = example[field.input_variable]
+            break
+
     completions: list[dict[str, Any]] = generator(prompt, **kwargs)
     completions: list[Example] = [template.extract(example, p) for p in completions]
```
9 changes: 8 additions & 1 deletion dsp/templates/template_v2.py
```diff
@@ -7,7 +7,7 @@
 
 from .utils import format_answers, passages2text
 
-Field = namedtuple("Field", "name separator input_variable output_variable description")
+Field = namedtuple("Field", "name separator input_variable output_variable description is_image")
 
 # TODO: de-duplicate with dsp/templates/template.py
 
@@ -63,6 +63,8 @@ def __init__(
                 input_variable=input_variable,
                 output_variable=output_variable,
                 description=description,
+                # See template_v3 for handling image fields
+                is_image=False,
             ),
         )
 
@@ -94,6 +96,11 @@ def query(self, example: Example, is_demo: bool = False) -> str:
 
         for field in self.fields:
             if field.input_variable in example and example[field.input_variable] is not None:
+                if field.is_image:
+                    # For image fields, we do not include this in the prompt
+                    # since we want to pass this straight to the LM
+                    continue
+
                 if field.input_variable in self.format_handlers:
                     format_handler = self.format_handlers[field.input_variable]
                 else:
```
8 changes: 4 additions & 4 deletions dsp/templates/template_v3.py
```diff
@@ -44,13 +44,13 @@ def __init__(self, instructions: str, **kwargs):
                 input_variable=key,
                 output_variable=key,
                 separator=separator,
+                is_image=bool(value.is_image),
             )
             self.fields.append(field)
 
             if value.format:
                 self.format_handlers[key] = value.format
-
 
 
     # equality
     def __eq__(self, other):
         if set(self.kwargs.keys()) != set(other.kwargs.keys()):
@@ -62,7 +62,7 @@ def __eq__(self, other):
             if not v1 == v2:
                 print(k, v1, v2)
 
-        
+
         # print("here?", self.instructions == other.instructions, self.kwargs == other.kwargs)
         return self.instructions == other.instructions and self.kwargs == other.kwargs
 
@@ -71,4 +71,4 @@ def __str__(self) -> str:
         field_names = [field.name for field in self.fields]
 
         return f"Template({self.instructions}, {field_names})"
-    
+
```
14 changes: 8 additions & 6 deletions dspy/signatures/field.py
```diff
@@ -3,7 +3,7 @@
 # The following arguments can be used in DSPy InputField and OutputField in addition
 # to the standard pydantic.Field arguments. We just hope pydanitc doesn't add these,
 # as it would give a name clash.
-DSPY_FIELD_ARG_NAMES = ["desc", "prefix", "format", "parser", "__dspy_field_type"]
+DSPY_FIELD_ARG_NAMES = ["desc", "prefix", "format", "is_image", "parser", "__dspy_field_type"]
 
 
 def move_kwargs(**kwargs):
@@ -39,16 +39,18 @@ def new_to_old_field(field):
         prefix=field.json_schema_extra["prefix"],
         desc=field.json_schema_extra["desc"],
         format=field.json_schema_extra.get("format"),
+        is_image=field.json_schema_extra.get("is_image"),
     )
 
 
 class OldField:
     """A more ergonomic datatype that infers prefix and desc if omitted."""
 
-    def __init__(self, *, prefix=None, desc=None, input, format=None):
+    def __init__(self, *, prefix=None, desc=None, input, format=None, is_image=False):
         self.prefix = prefix  # This can be None initially and set later
         self.desc = desc
         self.format = format
+        self.is_image = is_image
 
     def finalize(self, key, inferred_prefix):
         """Set the prefix if it's not provided explicitly."""
@@ -66,10 +68,10 @@ def __eq__(self, __value: object) -> bool:
 
 
 class OldInputField(OldField):
-    def __init__(self, *, prefix=None, desc=None, format=None):
-        super().__init__(prefix=prefix, desc=desc, input=True, format=format)
+    def __init__(self, *, prefix=None, desc=None, format=None, is_image=False):
+        super().__init__(prefix=prefix, desc=desc, input=True, format=format, is_image=is_image)
 
 
 class OldOutputField(OldField):
-    def __init__(self, *, prefix=None, desc=None, format=None):
-        super().__init__(prefix=prefix, desc=desc, input=False, format=format)
+    def __init__(self, *, prefix=None, desc=None, format=None, is_image=False):
+        super().__init__(prefix=prefix, desc=desc, input=False, format=format, is_image=is_image)
```
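With `is_image` plumbed through both the old-style and new-style fields, a signature can mark an input as an image: `template_v2.query()` then leaves it out of the rendered prompt, and `do_generate()` forwards its value to the LM under `IMG_DATA_KEY`. A sketch of how such a signature might be declared and used (all names are illustrative, and `img_b64` is the base64 string from the earlier examples):

```python
import dspy


class DescribeImage(dspy.Signature):
    """Answer a question about the supplied image."""

    image = dspy.InputField(is_image=True, desc="base64-encoded JPEG")
    question = dspy.InputField(desc="question about the image")
    answer = dspy.OutputField(desc="short answer")


# Assumes dspy.settings is configured with a vision-capable LM,
# e.g. the GPT-4o client shown above.
predict = dspy.Predict(DescribeImage)
result = predict(image=img_b64, question="What objects are visible?")
```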
27 changes: 21 additions & 6 deletions dspy/teleprompt/mipro_optimizer.py
```diff
@@ -3,12 +3,13 @@
 import sys
 import textwrap
 from collections import defaultdict
-from typing import Any
+from typing import Any, Callable, Optional
 
 import optuna
 
 import dsp
 import dspy
+from dsp.templates.template_v2 import Field
 from dspy.evaluate.evaluate import Evaluate
 from dspy.signatures import Signature
 from dspy.signatures.signature import signature_to_template
@@ -38,7 +39,7 @@
 * init_temperature: The temperature used to generate new prompts. Higher roughly equals more creative. Default=1.0.
 * verbose: Tells the method whether or not to print intermediate steps.
 * track_stats: Tells the method whether or not to track statistics about the optimization process.
-                If True, the method will track a dictionary with a key corresponding to the trial number, 
+                If True, the method will track a dictionary with a key corresponding to the trial number,
                 and a value containing a dict with the following keys:
                     * program: the program being evaluated at a given trial
                     * score: the last average evaluated score for the program
@@ -131,6 +132,8 @@ class DatasetDescriptorWithPriorObservations(dspy.Signature):
         desc="Somethings that holds true for most or all of the data you observed or COMPLETE if you have nothing to add",
     )
 
+ExampleStringifyFn = Callable[[dsp.Example, list[str]], str]
+
 
 class MIPRO(Teleprompter):
     def __init__(
@@ -144,6 +147,7 @@ def __init__(
         verbose=False,
         track_stats=True,
         view_data_batch_size=10,
+        example_stringify_fn: Optional[ExampleStringifyFn] = None,
     ):
         self.num_candidates = num_candidates
         self.metric = metric
@@ -154,6 +158,7 @@ def __init__(
         self.track_stats = track_stats
         self.teacher_settings = teacher_settings
         self.view_data_batch_size = view_data_batch_size
+        self.example_stringify_fn = example_stringify_fn
 
     def _print_full_program(self, program):
         for i, predictor in enumerate(program.predictors()):
@@ -174,7 +179,12 @@ def _print_model_history(self, model, n=1):
 
     def _observe_data(self, trainset, max_iterations=10):
         upper_lim = min(len(trainset), self.view_data_batch_size)
-        observation = dspy.Predict(DatasetDescriptor, n=1, temperature=1.0)(examples=(trainset[0:upper_lim].__repr__()))
+        example_stringify_fn = self.example_stringify_fn
+        if example_stringify_fn is None:
+            example_stringify_fn = lambda e, _fields: repr(e)
+        observation = dspy.Predict(DatasetDescriptor, n=1, temperature=1.0)(examples=(
+            "\n".join([example_stringify_fn(e, list(e.keys())) for e in trainset[0:upper_lim]])
+        ))
         observations = observation["observations"]
 
         skips = 0
@@ -183,7 +193,7 @@ def _observe_data(self, trainset, max_iterations=10):
             upper_lim = min(len(trainset), b + self.view_data_batch_size)
             output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)(
                 prior_observations=observations,
                 examples="\n".join([example_stringify_fn(e, list(e.keys())) for e in trainset[b:upper_lim]]),
             )
             iterations += 1
             if len(output["observations"]) >= 8 and output["observations"][:8].upper() == "COMPLETE":
@@ -199,8 +209,13 @@ def _observe_data(self, trainset, max_iterations=10):
 
         return summary.summary
 
-    def _create_example_string(self, fields, example):
+    def _create_example_string(self, fields: list[Field], example: dsp.Example) -> str:
         # Building the output string
+
+        if self.example_stringify_fn is not None:
+            field_names = [field.input_variable for field in fields]
+            return self.example_stringify_fn(example, field_names)
+
         output = []
         for field in fields:
             name = field.name
@@ -383,7 +398,7 @@ def compile(
             {YELLOW}{BOLD}Estimated Cost Calculation:{ENDC}
-            {YELLOW}Total Cost = (Number of calls to task model * (Avg Input Token Length per Call * Task Model Price per Input Token + Avg Output Token Length per Call * Task Model Price per Output Token) 
+            {YELLOW}Total Cost = (Number of calls to task model * (Avg Input Token Length per Call * Task Model Price per Input Token + Avg Output Token Length per Call * Task Model Price per Output Token)
                         + (Number of calls to prompt model * (Avg Input Token Length per Call * Task Prompt Price per Input Token + Avg Output Token Length per Call * Prompt Model Price per Output Token).{ENDC}
             For a preliminary estimate of potential costs, we recommend you perform your own calculations based on the task
```
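Because `_observe_data` previously called `repr()` on whole examples, a base64 image field would flood the proposer prompt with encoded bytes; `example_stringify_fn` lets the caller control how examples are rendered. Per `ExampleStringifyFn`, the callback receives the example and the list of field names. A hedged sketch of a callback that elides bulky image fields (the metric, placeholder, and length threshold are illustrative):

```python
from dspy.teleprompt import MIPRO


def stringify_without_images(example, field_names):
    # Render only the requested fields, replacing anything that looks
    # like inline image data with a short placeholder.
    parts = []
    for name in field_names:
        value = example.get(name)
        if isinstance(value, str) and len(value) > 1000:
            value = "<image data omitted>"
        parts.append(f"{name}: {value}")
    return " | ".join(parts)


# my_metric is a user-supplied DSPy metric; shown for illustration.
optimizer = MIPRO(metric=my_metric, example_stringify_fn=stringify_without_images)
```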
