Vectorize OpenAI model calls
rlouf committed Jun 6, 2023
1 parent 47d1459 commit b858751
Showing 2 changed files with 45 additions and 40 deletions.
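The core of the change is the new `outlines.vectorize` decorator applied to the model-calling functions below. The decorator itself is not part of this diff; assuming it follows NumPy's generalized-ufunc convention (the one `numpy.vectorize` uses), a signature such as `"(),(m),(),()->(s)"` declares each argument's core dimensionality: a scalar prompt, a length-`m` list of stop sequences, a scalar sample count, and a scalar mask, mapping to a length-`s` vector of completions. Any extra leading dimensions on the inputs are broadcast, so the same function serves one prompt or a whole array of prompts. A minimal NumPy-only sketch of the convention:

```python
import numpy as np

def fake_complete(prompt: str) -> np.ndarray:
    # Stand-in for one API call returning `s` samples for one prompt.
    return np.array([prompt + " -> answer"])

# "()->(s)": scalar in, 1-d vector of samples out. Extra leading
# dimensions of the input are looped over and kept in the output.
vectorized = np.vectorize(fake_complete, signature="()->(s)")

prompts = np.array(["What is 1+1?", "Name a prime."])
print(vectorized(prompts).shape)  # (2, 1): two prompts, one sample each
```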
2 changes: 1 addition & 1 deletion outlines/models/__init__.py
```diff
@@ -5,7 +5,7 @@
 codebase.
 """
-from . import image_generation, text_completion
+from . import embeddings, image_generation, text_completion
 from .hf_diffusers import HuggingFaceDiffuser
 from .hf_transformers import HuggingFaceCompletion
 from .openai import OpenAICompletion, OpenAIEmbeddings, OpenAIImageGeneration
```
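This re-export makes the embeddings model reachable from `outlines.models`. A hypothetical usage sketch, inferred from the `generate(query: str) -> np.ndarray` wrapper in the diff below; the constructor argument is an assumption, and a valid `OPENAI_API_KEY` must be set:

```python
from outlines.models.openai import OpenAIEmbeddings

# Model name is illustrative; OpenAIEmbeddings is assumed to return a
# function mapping a query string to a 1-d NumPy embedding vector.
embed = OpenAIEmbeddings("text-embedding-ada-002")
vector = embed("Vectorize OpenAI model calls")
print(vector.shape)  # e.g. (1536,)
```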
83 changes: 44 additions & 39 deletions outlines/models/openai.py
```diff
@@ -1,14 +1,16 @@
 """Integration with OpenAI's API."""
 import base64
+import functools
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import numpy as np
 from PIL import Image
 from PIL.Image import Image as PILImage
 
+import outlines
 from outlines.caching import cache
 
 __all__ = [
```
```diff
@@ -57,18 +59,22 @@ def OpenAICompletion(
             f"The model {model_name} requested is not available. Only the completion and chat completion models are available for OpenAI."
         )
 
-    def generate(prompt: str, *, samples=1, stop_at=None, is_in=None, type=None):
+    def generate(
+        prompt: str,
+        *,
+        samples=1,
+        stop_at: List[Optional[str]] = [],
+        is_in=None,
+        type=None,
+    ):
         import tiktoken
 
-        if stop_at is not None:
-            stop_at = tuple(stop_at)
-
         mask = {}
         if type is not None:
             encoder = tiktoken.encoding_for_model(model_name)
             mask = create_type_mask(type, encoder)
 
-        if is_in is not None and stop_at is not None:
+        if is_in is not None and stop_at:
             raise TypeError("You cannot set `is_in` and `stop_at` at the same time.")
         elif is_in is not None and len(mask) > 0:
             raise TypeError("You cannot set `is_in` and `mask` at the same time.")
```
```diff
@@ -77,10 +83,11 @@ def generate(prompt: str, *, samples=1, stop_at=None, is_in=None, type=None):
         else:
             return generate_base(prompt, stop_at, samples, mask)
 
-    def generate_base(
-        prompt: str, stop_at: Optional[Tuple[str]], samples: int, mask: Dict[int, int]
+    @functools.partial(outlines.vectorize, signature="(),(m),(),()->(s)")
+    async def generate_base(
+        prompt: str, stop_at: List[Optional[str]], samples: int, mask: Dict[int, int]
     ) -> str:
-        responses = call_api(
+        responses = await call_api(
             model_name,
             format_prompt(prompt),
             max_tokens,
```
```diff
@@ -91,13 +98,16 @@ def generate(prompt: str, *, samples=1, stop_at=None, is_in=None, type=None):
         )
 
         if samples == 1:
-            results = extract_choice(responses["choices"][0])
+            results = np.array([extract_choice(responses["choices"][0])])
         else:
-            results = [extract_choice(responses["choices"][i]) for i in range(samples)]
+            results = np.array(
+                [extract_choice(responses["choices"][i]) for i in range(samples)]
+            )
 
         return results
 
-    def generate_choice(
+    @functools.partial(outlines.vectorize, signature="(),(m),()->(s)")
+    async def generate_choice(
         prompt: str, is_in: List[str], samples: int
     ) -> Union[List[str], str]:
         """Generate a sequence that must be one of many options.
```
```diff
@@ -130,12 +140,12 @@ def generate_choice(
                 if len(mask) == 0:
                     break
 
-                response = call_api(
+                response = await call_api(
                     model_name,
                     format_prompt(prompt),
                     1,
                     temperature,
-                    None,
+                    [],
                     mask,
                     samples,
                 )
```
```diff
@@ -144,10 +154,7 @@ def generate_choice(
 
             decoded_samples.append("".join(decoded))
 
-        if samples == 1:
-            return decoded_samples[0]
-
-        return decoded_samples
+        return np.array(decoded_samples)
 
     return generate
```
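With `generate_base` and `generate_choice` vectorized, the `generate` closure returned above should accept arrays of prompts and issue one API call per element. A hypothetical call, assuming the `text_completion.openai` entry point wraps `OpenAICompletion` (neither the entry point nor the model name appears in this diff):

```python
import numpy as np
import outlines.models as models

complete = models.text_completion.openai("text-davinci-003")

prompts = np.array(["2 + 2 =", "The capital of France is"])
answers = complete(prompts)
print(answers.shape)  # expected (2, 1): one sample per prompt
```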

```diff
@@ -180,11 +187,12 @@ def call_embeddings_api(
 
         response = openai.Embedding.create(
             model=model,
-            input=input,
+            input=list(input),
         )
 
         return response
 
+    @functools.partial(outlines.vectorize, signature="()->(s)")
     def generate(query: str) -> np.ndarray:
         api_response = call_embeddings_api(model_name, query)
         response = api_response["data"][0]["embedding"]
```
```diff
@@ -216,28 +224,25 @@ def OpenAIImageGeneration(model_name: str = "", size: str = "512x512"):
 
     @error_handler
     @cache
-    def call_image_generation_api(prompt: str, size: str, samples: int):
+    async def call_image_generation_api(prompt: str, size: str, samples: int):
         import openai
 
-        response = openai.Image.create(
-            prompt=prompt, size=size, n=samples, response_format="b64_json"
+        response = await openai.Image.acreate(
+            prompt=prompt, size=size, n=int(samples), response_format="b64_json"
         )
 
         return response
 
-    def generate(prompt: str, samples: int = 1) -> PILImage:
-        api_response = call_image_generation_api(prompt, size, samples)
-
-        if samples == 1:
-            response = api_response["data"][0]["b64_json"]
-            return Image.open(BytesIO(base64.b64decode(response)))
+    @functools.partial(outlines.vectorize, signature="(),()->(s)")
+    async def generate(prompt: str, samples: int = 1) -> PILImage:
+        api_response = await call_image_generation_api(prompt, size, samples)
 
         images = []
         for i in range(samples):
             response = api_response["data"][i]["b64_json"]
             images.append(Image.open(BytesIO(base64.b64decode(response))))
 
-        return images
+        return np.array(images, dtype="object")
 
     return generate
```
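The `dtype="object"` in the new return value is load-bearing: PIL images expose NumPy's array interface, so a plain `np.array(images)` would decode them into one stacked pixel array rather than keep the image objects. A self-contained illustration:

```python
import numpy as np
from PIL import Image

images = [Image.new("RGB", (4, 4)), Image.new("RGB", (4, 4))]

pixels = np.array(images)                   # stacked pixel data: (2, 4, 4, 3)
objects = np.array(images, dtype="object")  # PIL objects kept: shape (2,)

print(pixels.shape, objects.shape)
print(type(objects[0]))  # <class 'PIL.Image.Image'>
```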

```diff
@@ -335,51 +340,51 @@ def call(*args, **kwargs):
 
 @error_handler
 @cache
-def call_completion_api(
+async def call_completion_api(
     model: str,
     prompt: str,
     max_tokens: int,
     temperature: float,
-    stop_sequences: Tuple[str],
+    stop_sequences: List[str],
     logit_bias: Dict[str, int],
     num_samples: int,
 ):
     import openai
 
-    response = openai.Completion.create(
+    response = await openai.Completion.acreate(
         engine=model,
         prompt=prompt,
         temperature=temperature,
         max_tokens=max_tokens,
-        stop=stop_sequences,
+        stop=list(stop_sequences) if len(stop_sequences) > 0 else None,
         logit_bias=logit_bias,
-        n=num_samples,
+        n=int(num_samples),
     )
 
     return response
 
 
 @error_handler
 @cache
-def call_chat_completion_api(
+async def call_chat_completion_api(
     model: str,
     messages: List[Dict[str, str]],
     max_tokens: int,
     temperature: float,
-    stop_sequences: Tuple[str],
+    stop_sequences: List[str],
     logit_bias: Dict[str, int],
     num_samples: int,
 ):
     import openai
 
-    response = openai.ChatCompletion.create(
+    response = await openai.ChatCompletion.acreate(
         model=model,
         messages=messages,
         max_tokens=max_tokens,
         temperature=temperature,
-        stop=stop_sequences,
+        stop=list(stop_sequences) if len(stop_sequences) > 0 else None,
         logit_bias=logit_bias,
-        n=num_samples,
+        n=int(num_samples),
     )
 
     return response
```
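Swapping every `openai.*.create` call for its asynchronous `acreate` counterpart is what makes the vectorization pay off: a vectorizing wrapper can await one coroutine per array element concurrently instead of blocking on each HTTP round trip in turn. The pattern, sketched here independently of `outlines` with illustrative names only:

```python
import asyncio

async def call_api(prompt: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for one HTTP round trip
    return prompt + " -> done"

async def main():
    prompts = ["a", "b", "c"]
    # One coroutine per element, awaited together: total latency is
    # roughly one round trip rather than len(prompts) round trips.
    print(await asyncio.gather(*(call_api(p) for p in prompts)))

asyncio.run(main())
```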
