Skip to content

Commit

Permalink
Allow to pass all AsyncOpenai client parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jun 14, 2024
1 parent 18aaba1 commit ed8e6ec
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,16 @@ def call(*args, **kwargs):
return call


def openai(
@functools.singledispatch
def openai(model_or_client, *args, **kwargs):
return OpenAI(model_or_client, *args, **kwargs)


@openai.register(str)
def openai_model(
model_name: str,
api_key: Optional[str] = None,
config: Optional[OpenAIConfig] = None,
**openai_client_params,
):
try:
import tiktoken
Expand All @@ -432,7 +438,7 @@ def openai(
else:
config = OpenAIConfig(model=model_name)

client = AsyncOpenAI(api_key=api_key)
client = AsyncOpenAI(**openai_client_params)
tokenizer = tiktoken.encoding_for_model(model_name)

return OpenAI(client, config, tokenizer)
Expand All @@ -441,10 +447,8 @@ def openai(
def azure_openai(
deployment_name: str,
model_name: Optional[str] = None,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
config: Optional[OpenAIConfig] = None,
**azure_openai_client_params,
):
try:
import tiktoken
Expand All @@ -459,9 +463,7 @@ def azure_openai(
if config is None:
config = OpenAIConfig(model=deployment_name)

client = AsyncAzureOpenAI(
azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key
)
client = AsyncAzureOpenAI(**azure_openai_client_params)
tokenizer = tiktoken.encoding_for_model(model_name or deployment_name)

return OpenAI(client, config, tokenizer)

0 comments on commit ed8e6ec

Please sign in to comment.