Commit ec94e7a
Create a new AsyncOpenAI client for each request
The client currently raises a `TimeoutError` after many requests. This
appears to be a problem on the OpenAI side, but we provide this temporary
fix so the OpenAI integration keeps working in Outlines.
rlouf committed Dec 14, 2023
1 parent 4f34714 commit ec94e7a
Showing 1 changed file with 14 additions and 4 deletions.
outlines/models/openai.py (14 additions & 4 deletions)
@@ -126,9 +126,15 @@ def __init__(
         else:
             self.config = OpenAIConfig(model=model_name)
 
-        self.client = openai.AsyncOpenAI(
-            api_key=api_key, max_retries=max_retries, timeout=timeout
+        # This is necessary because of an issue with the OpenAI API.
+        # Status updates: https://github.com/openai/openai-python/issues/769
+        self.create_client = functools.partial(
+            openai.AsyncOpenAI,
+            api_key=api_key,
+            max_retries=max_retries,
+            timeout=timeout,
         )
 
         self.system_prompt = system_prompt
 
         # We count the total number of prompt and generated tokens as returned
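
The workaround in this hunk is easiest to see outside the class. Below is a minimal, self-contained sketch of the same pattern: the create_client partial mirrors the diff, while ask, the placeholder key, and the parameter values are illustrative and not part of Outlines.

import functools

import openai

# Bind the constructor arguments once; every call to the partial builds a
# brand-new AsyncOpenAI client, so no single client lives long enough to
# trigger the timeout issue tracked in openai/openai-python#769.
create_client = functools.partial(
    openai.AsyncOpenAI,
    api_key="sk-...",  # illustrative placeholder
    max_retries=2,
    timeout=30.0,
)


async def ask(prompt: str) -> str:
    client = create_client()  # fresh client for this request only
    try:
        response = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content
    finally:
        await client.close()  # release the underlying HTTP connections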
@@ -173,8 +179,9 @@ def __call__(
             )
         )
         if "gpt-" in self.config.model:
+            client = self.create_client()
             response, prompt_tokens, completion_tokens = generate_chat(
-                prompt, self.system_prompt, self.client, config
+                prompt, self.system_prompt, client, config
             )
             self.prompt_tokens += prompt_tokens
             self.completion_tokens += completion_tokens
@@ -232,8 +239,9 @@ def generate_choice(
 
         config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
 
+        client = self.create_client()
         response, prompt_tokens, completion_tokens = generate_chat(
-            prompt, self.system_prompt, self.client, config
+            prompt, self.system_prompt, client, config
         )
         self.prompt_tokens += prompt_tokens
         self.completion_tokens += completion_tokens
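
Both call sites follow the same shape: build a throwaway client, hand it to generate_chat, and fold the usage counters back into the model object. For readers without the rest of the file, here is a simplified, free-standing stand-in matching the (prompt, system_prompt, client, config) signature seen in the diff; the body is illustrative, and the real Outlines function (with its caching and error handling) differs.

from dataclasses import asdict


# Illustrative stand-in, not Outlines' implementation: assumes `config` is a
# dataclass (e.g. OpenAIConfig) whose fields are valid chat-completion kwargs.
async def generate_chat_sketch(prompt, system_prompt, client, config):
    system_message = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    user_message = [{"role": "user", "content": prompt}]
    responses = await client.chat.completions.create(
        messages=system_message + user_message,
        **asdict(config),
    )
    data = responses.model_dump()
    usage = data["usage"]
    return data, usage["prompt_tokens"], usage["completion_tokens"]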
@@ -315,6 +323,8 @@ async def call_api(prompt, system_prompt, config):
             messages=system_message + user_message,
             **asdict(config),  # type: ignore
         )
+        await client.close()
+
         return responses.model_dump()
 
     system_message = (
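
Because a fresh AsyncOpenAI client is now built for every request, each one must be closed after use, which is what the added `await client.close()` does; otherwise every call would leak an HTTP connection pool. A sketch of the same hygiene via a context manager, assuming a recent openai-python where the client supports `async with` and the API key comes from the OPENAI_API_KEY environment variable:

import openai


async def call_once(prompt: str) -> dict:
    # "async with" closes the client (and its connection pool) on exit,
    # matching the explicit await client.close() in the diff above.
    async with openai.AsyncOpenAI() as client:
        responses = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
        )
    return responses.model_dump()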
