diff --git a/outlines/models/openai.py b/outlines/models/openai.py index cedb37420..8625f4d3d 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -414,10 +414,16 @@ def call(*args, **kwargs): return call -def openai( +@functools.singledispatch +def openai(model_or_client, *args, **kwargs): + return OpenAI(model_or_client, *args, **kwargs) + + +@openai.register(str) +def openai_model( model_name: str, - api_key: Optional[str] = None, config: Optional[OpenAIConfig] = None, + **openai_client_params, ): try: import tiktoken @@ -432,7 +438,7 @@ def openai( else: config = OpenAIConfig(model=model_name) - client = AsyncOpenAI(api_key=api_key) + client = AsyncOpenAI(**openai_client_params) tokenizer = tiktoken.encoding_for_model(model_name) return OpenAI(client, config, tokenizer) @@ -441,10 +447,8 @@ def openai( def azure_openai( deployment_name: str, model_name: Optional[str] = None, - azure_endpoint: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, config: Optional[OpenAIConfig] = None, + **azure_openai_client_params, ): try: import tiktoken @@ -459,9 +463,7 @@ def azure_openai( if config is None: config = OpenAIConfig(model=deployment_name) - client = AsyncAzureOpenAI( - azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key - ) + client = AsyncAzureOpenAI(**azure_openai_client_params) tokenizer = tiktoken.encoding_for_model(model_name or deployment_name) return OpenAI(client, config, tokenizer)