Fix vectorize signature and caching in OpenAI integration
rlouf committed Dec 13, 2023
1 parent 1e0b0af commit add3743
Showing 2 changed files with 24 additions and 16 deletions.
examples/math_generate_code.py (1 addition, 1 deletion)
@@ -35,6 +35,6 @@ def execute_code(code):


prompt = answer_with_code_prompt(question, examples)
answer = models.openai("gpt-4")(prompt)
answer = models.openai("gpt-3.5-turbo")(prompt)
result = execute_code(answer)
print(f"It takes Carla {result:.0f} minutes to download the file.")
outlines/models/openai.py (23 additions, 15 deletions)
@@ -173,11 +173,11 @@ def __call__(
            )
        )
        if "gpt-" in self.config.model:
-            response, usage = generate_chat(
+            response, prompt_tokens, completion_tokens = generate_chat(
                prompt, self.system_prompt, self.client, config
            )
-            self.prompt_tokens += usage["prompt_tokens"]
-            self.completion_tokens += usage["completion_tokens"]
+            self.prompt_tokens += prompt_tokens
+            self.completion_tokens += completion_tokens

            return response

@@ -232,11 +232,11 @@ def generate_choice(

            config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)

-            response, usage = generate_chat(
+            response, prompt_tokens, completion_tokens = generate_chat(
                prompt, self.system_prompt, self.client, config
            )
-            self.completion_tokens += usage["completion_tokens"]
-            self.prompt_tokens += usage["prompt_tokens"]
+            self.prompt_tokens += prompt_tokens
+            self.completion_tokens += completion_tokens

            encoded_response = tokenizer.encode(response)

@@ -281,14 +281,13 @@ def __repr__(self):
        return str(self.config)


-@cache(ignore="client")
-@functools.partial(vectorize, signature="(),(),()->(s)")
+@functools.partial(vectorize, signature="(),(),(),()->(s),(),()")
async def generate_chat(
    prompt: str,
    system_prompt: Union[str, None],
    client: "AsyncOpenAI",
    config: OpenAIConfig,
-) -> Tuple[np.ndarray, Dict]:
+) -> Tuple[np.ndarray, int, int]:
    """Call OpenAI's Chat Completion API.

    Parameters
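The signature passed to vectorize appears to follow NumPy's generalized-ufunc notation, so it must name every broadcast input and every output: the old "(),(),()->(s)" declared only three inputs and a single response array, while the corrected "(),(),(),()->(s),(),()" covers all four arguments (prompt, system prompt, client, config) plus the two scalar outputs for prompt and completion token counts. A rough sketch of that shape using plain numpy.vectorize; the fake_generate stand-in below is purely illustrative, not outlines code:

import numpy as np

# Stand-in for generate_chat: four scalar inputs map to one response vector
# with core dimension s plus two scalar token counts.
def fake_generate(prompt, system_prompt, client, config):
    responses = np.array([prompt.upper(), prompt.lower()])  # s == 2
    prompt_tokens = len(prompt.split())
    completion_tokens = 2 * prompt_tokens
    return responses, prompt_tokens, completion_tokens

vectorized = np.vectorize(fake_generate, signature="(),(),(),()->(s),(),()")

out, n_prompt, n_completion = vectorized(
    np.array(["General Kenobi", "Hello there"]), "system", "client", "config"
)
print(out.shape)        # (2, 2): one (s,) response array per input prompt
print(n_prompt)         # [2 2]
print(n_completion)     # [4 4]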
@@ -309,19 +309,28 @@ async def generate_chat(
    A tuple that contains the model's response(s) and usage statistics.
    """

+    @cache()
+    async def call_api(prompt, system_prompt, config):
+        responses = await client.chat.completions.create(
+            messages=system_message + user_message,
+            **asdict(config),  # type: ignore
+        )
+        return responses.model_dump()
+
    system_message = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    user_message = [{"role": "user", "content": prompt}]

-    responses = await client.chat.completions.create(
-        messages=system_message + user_message,
-        **asdict(config),  # type: ignore
-    )
+    responses = await call_api(prompt, system_prompt, config)

-    results = np.array([responses.choices[i].message.content for i in range(config.n)])
+    results = np.array(
+        [responses["choices"][i]["message"]["content"] for i in range(config.n)]
+    )
+    usage = responses["usage"]

-    return results, responses.usage.model_dump()
+    return results, usage["prompt_tokens"], usage["completion_tokens"]


openai = OpenAI
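Previously the cache decorated the vectorized generate_chat itself and had to exclude the client with ignore="client", presumably because the API client object is not a usable cache key. After this change the cache sits on a nested call_api coroutine whose arguments (prompt, system prompt, config) can all serve as the key, the client is simply captured by the closure, and the cached value is the plain dict returned by model_dump() rather than a live response object. A toy sketch of that pattern follows; the async_memoize decorator, FakeClient, and FakeConfig are illustrative stand-ins, not outlines' cache() decorator:

import asyncio
import functools
from dataclasses import dataclass

_results = {}  # module-level store, so entries survive across generate() calls


def async_memoize(fn):
    # Memoize an async function on its (hashable) positional arguments.
    @functools.wraps(fn)
    async def wrapper(*args):
        key = (fn.__name__, args)
        if key not in _results:
            _results[key] = await fn(*args)
        return _results[key]

    return wrapper


@dataclass(frozen=True)  # frozen -> hashable, so the config can be part of the key
class FakeConfig:
    model: str = "gpt-3.5-turbo"
    n: int = 1


class FakeClient:  # stands in for the API client, which is not a valid cache key
    async def complete(self, prompt):
        return {
            "choices": [{"message": {"content": prompt[::-1]}}],
            "usage": {"prompt_tokens": 3, "completion_tokens": 3},
        }


async def generate(prompt, system_prompt, client, config):
    # The client is captured by the closure instead of appearing in the key,
    # mirroring the nested call_api introduced in this commit.
    @async_memoize
    async def call_api(prompt, system_prompt, config):
        return await client.complete(prompt)

    response = await call_api(prompt, system_prompt, config)
    usage = response["usage"]
    content = response["choices"][0]["message"]["content"]
    return content, usage["prompt_tokens"], usage["completion_tokens"]


print(asyncio.run(generate("abc", None, FakeClient(), FakeConfig())))
# ('cba', 3, 3); a second identical call reuses the memoized API response.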