Skip to content

Commit

Permalink
Register more OpenAI models.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567694333
  • Loading branch information
daiyip authored and langfun authors committed Sep 22, 2023
1 parent 10b4d2a commit f8833ce
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 9 deletions.
13 changes: 12 additions & 1 deletion langfun/core/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,21 @@
from langfun.core.llms.openai import OpenAI

from langfun.core.llms.openai import Gpt4
from langfun.core.llms.openai import Gpt4_0613
from langfun.core.llms.openai import Gpt4_0314
from langfun.core.llms.openai import Gpt4_32K
from langfun.core.llms.openai import Gpt35
from langfun.core.llms.openai import Gpt4_32K_0613
from langfun.core.llms.openai import Gpt4_32K_0314

from langfun.core.llms.openai import Gpt35Turbo
from langfun.core.llms.openai import Gpt35Turbo_0613
from langfun.core.llms.openai import Gpt35Turbo_0301
from langfun.core.llms.openai import Gpt35Turbo16K
from langfun.core.llms.openai import Gpt35Turbo16K_0613
from langfun.core.llms.openai import Gpt35Turbo16K_0301

from langfun.core.llms.openai import Gpt35

from langfun.core.llms.openai import Gpt3
from langfun.core.llms.openai import Gpt3Curie
from langfun.core.llms.openai import Gpt3Babbage
Expand Down
56 changes: 48 additions & 8 deletions langfun/core/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,16 @@ class OpenAI(lf.LanguageModel):
model: pg.typing.Annotated[
Literal[
'gpt-4',
'gpt-4-0613',
'gpt-4-0314',
'gpt-4-32k',
'gpt-4-32k-0613',
'gpt-4-32k-0314',
'gpt-3.5-turbo',
'gpt-3.5-turbo-0613',
'gpt-3.5-turbo-0301',
'gpt-3.5-turbo-16k',
'gpt-3.5-turbo-16k-0613',
'text-davinci-003',
'davinci',
'curie',
Expand All @@ -54,13 +61,6 @@ class OpenAI(lf.LanguageModel):
'The name of the model to use.',
] = 'gpt-3.5-turbo'

_CHAT_MODELS = frozenset([
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
])

api_key: Annotated[
str | None,
(
Expand Down Expand Up @@ -103,7 +103,7 @@ def dir(cls):
@property
def is_chat_model(self):
"""Returns True if the model is a chat model."""
return self.model in OpenAI._CHAT_MODELS
return self.model.startswith(('gpt-4', 'gpt-3.5-turbo'))

def _get_request_args(
self, options: lf.LMSamplingOptions) -> dict[str, Any]:
Expand Down Expand Up @@ -208,11 +208,31 @@ class Gpt4(OpenAI):
model = 'gpt-4'


class Gpt4_0613(Gpt4): # pylint:disable=invalid-name
"""GPT-4 0613."""
model = 'gpt-4-0613'


class Gpt4_0314(Gpt4): # pylint:disable=invalid-name
"""GPT-4 0314."""
model = 'gpt-4-0314'


class Gpt4_32K(Gpt4): # pylint:disable=invalid-name
"""GPT-4 with 32K context window size."""
model = 'gpt-4-32k'


class Gpt4_32K_0613(Gpt4_32K): # pylint:disable=invalid-name
"""GPT-4 32K 0613."""
model = 'gpt-4-32k-0613'


class Gpt4_32K_0314(Gpt4_32K): # pylint:disable=invalid-name
"""GPT-4 32K 0314."""
model = 'gpt-4-32k-0314'


class Gpt35(OpenAI):
"""GPT-3.5. 4K max tokens, trained up on data up to Sep, 2021."""
model = 'text-davinci-003'
Expand All @@ -223,11 +243,31 @@ class Gpt35Turbo(Gpt35):
model = 'gpt-3.5-turbo'


class Gpt35Turbo_0613(Gpt35Turbo): # pylint:disable=invalid-name
"""Gtp 3.5 Turbo 0613."""
model = 'gpt-3.5-turbo-0613'


class Gpt35Turbo_0301(Gpt35Turbo): # pylint:disable=invalid-name
"""Gtp 3.5 Turbo 0301."""
model = 'gpt-3.5-turbo-0301'


class Gpt35Turbo16K(Gpt35Turbo):
"""Latest GPT-3.5 model with 16K context window size."""
model = 'gpt-3.5-turbo-16k'


class Gpt35Turbo16K_0613(Gpt35Turbo): # pylint:disable=invalid-name
"""Gtp 3.5 Turbo 16K 0613."""
model = 'gpt-3.5-turbo-16k-0613'


class Gpt35Turbo16K_0301(Gpt35Turbo): # pylint:disable=invalid-name
"""Gtp 3.5 Turbo 16K 0301."""
model = 'gpt-3.5-turbo-16k-0301'


class Gpt3(OpenAI):
"""Most capable GPT-3 model (Davinci) 2K context window size.
Expand Down

0 comments on commit f8833ce

Please sign in to comment.