Skip to content

Commit

Permalink
Use the OpenAI API to check that a model exists
Browse files Browse the repository at this point in the history
  • Loading branch information
remic33 authored and rlouf committed Jan 10, 2024
1 parent 02fe25a commit 37f53ca
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 13 deletions.
11 changes: 7 additions & 4 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ def __init__(
parameters that cannot be set by calling this class' methods.
"""
if model_name not in ["gpt-4", "gpt-3.5-turbo"]:
raise ValueError(
"Invalid model_name. It must be either 'gpt-4' or 'gpt-3.5-turbo'."
)

try:
import openai
Expand All @@ -125,6 +121,13 @@ def __init__(
raise ValueError(
"You must specify an API key to use the OpenAI API integration."
)
try:
client = openai.OpenAI()
client.models.retrieve(model_name)
except openai.NotFoundError:
raise ValueError(
"Invalid model_name. Check openai models list at https://platform.openai.com/docs/models"
)

if config is not None:
self.config = replace(config, model=model_name) # type: ignore
Expand Down
9 changes: 0 additions & 9 deletions tests/models/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest

from outlines.models.openai import (
OpenAI,
build_optimistic_mask,
find_longest_intersection,
find_response_choices_intersection,
Expand Down Expand Up @@ -49,11 +48,3 @@ def test_find_longest_common_prefix(response, choice, expected_prefix):
def test_build_optimistic_mask(transposed, mask_size, expected_mask):
mask = build_optimistic_mask(transposed, mask_size)
assert mask == expected_mask


def test_model_name_validation():
with pytest.raises(ValueError):
OpenAI(model_name="invalid_model_name")

with pytest.raises(ValueError):
OpenAI(model_name="gpt-4-1106-preview")

0 comments on commit 37f53ca

Please sign in to comment.