automatically search for openai key
csinva committed May 21, 2024
1 parent de3f0f9 commit a1ba119
Showing 1 changed file with 9 additions and 74 deletions.
imodelsx/llm.py (83 changes: 9 additions & 74 deletions)
@@ -17,14 +17,19 @@
 from scipy.special import softmax
 import hashlib
 import torch
+from os.path import expanduser
 import time
 from tqdm import tqdm
 
 HF_TOKEN = None
 if 'HF_TOKEN' in os.environ:
     HF_TOKEN = os.environ.get("HF_TOKEN")
-elif os.path.exists('~/.HF_TOKEN'):
-    HF_TOKEN = open(os.path.expanduser('~/.HF_TOKEN'), 'r').read().strip()
+elif os.path.exists(expanduser('~/.HF_TOKEN')):
+    HF_TOKEN = open(expanduser('~/.HF_TOKEN'), 'r').read().strip()
+if 'OPENAI_API_KEY' in os.environ:
+    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+elif os.path.exists(expanduser('~/.OPENAI_API_KEY')):
+    OPENAI_API_KEY = open(expanduser('~/.OPENAI_API_KEY'), 'r').read().strip()
 '''
 Example usage:
 # gpt-4, gpt-35-turbo, meta-llama/Llama-2-70b-hf, mistralai/Mistral-7B-v0.1
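
The fallback logic added here is the same for both credentials: prefer the environment variable, then fall back to a dotfile in the home directory. A minimal sketch of that pattern as a reusable helper (the name _read_credential is hypothetical, not part of the repository):

    import os
    from os.path import expanduser

    def _read_credential(name):
        """Return a credential from $<name>, else from ~/.<name>, else None."""
        # environment variable takes precedence
        if name in os.environ:
            return os.environ[name]
        # fall back to a dotfile in the home directory, e.g. ~/.OPENAI_API_KEY
        dotfile = expanduser(f'~/.{name}')
        if os.path.exists(dotfile):
            with open(dotfile, 'r') as f:
                return f.read().strip()
        return None

    OPENAI_API_KEY = _read_credential('OPENAI_API_KEY')
    HF_TOKEN = _read_credential('HF_TOKEN')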
@@ -58,9 +63,7 @@ def get_llm(
     LLM_CONFIG["LLM_REPEAT_DELAY"] = repeat_delay
 
     """Get an LLM with a call function and caching capabilities"""
-    if checkpoint.startswith("text-da"):
-        return LLM_OpenAI(checkpoint, seed=seed, CACHE_DIR=CACHE_DIR)
-    elif checkpoint.startswith("gpt-3") or checkpoint.startswith("gpt-4"):
+    if checkpoint.startswith("gpt-3") or checkpoint.startswith("gpt-4"):
         return LLM_Chat(checkpoint, seed, role, CACHE_DIR)
     elif 'Meta-Llama-3' in checkpoint and 'Instruct' in checkpoint:
         return LLM_HF_Pipeline(checkpoint, CACHE_DIR)
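
With the text-davinci branch gone, every gpt-3*/gpt-4* checkpoint now routes to LLM_Chat. A hedged usage sketch, assuming the remaining get_llm arguments have sensible defaults:

    from imodelsx.llm import get_llm

    # the key is discovered at import time from $OPENAI_API_KEY or ~/.OPENAI_API_KEY
    chat_llm = get_llm("gpt-4")  # gpt-3*/gpt-4* checkpoints -> LLM_Chat
    hf_llm = get_llm("meta-llama/Meta-Llama-3-8B-Instruct")  # -> LLM_HF_Pipeline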
@@ -99,73 +102,6 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-class LLM_OpenAI:
-    def __init__(self, checkpoint, seed, CACHE_DIR):
-        self.cache_dir = join(
-            CACHE_DIR, "cache_openai", f'{checkpoint.replace("/", "_")}___{seed}'
-        )
-        self.checkpoint = checkpoint
-
-    @repeatedly_call_with_delay
-    def __call__(
-        self,
-        prompt: str,
-        max_new_tokens=250,
-        frequency_penalty=0.25,  # maximum is 2
-        temperature=0.1,
-        do_sample=True,
-        stop=None,
-        return_str=True,
-    ):
-
-        # cache
-        os.makedirs(self.cache_dir, exist_ok=True)
-        hash_str = hashlib.sha256(prompt.encode()).hexdigest()
-        cache_file = join(
-            self.cache_dir, f"{hash_str}__num_tok={max_new_tokens}.pkl")
-        if os.path.exists(cache_file):
-            return pkl.load(open(cache_file, "rb"))
-
-        # import openai
-        # response = openai.Completion.create(
-        #     engine=self.checkpoint,
-        #     prompt=prompt,
-        #     max_tokens=max_new_tokens,
-        #     temperature=temperature,
-        #     top_p=1,
-        #     frequency_penalty=frequency_penalty,
-        #     presence_penalty=0,
-        #     stop=stop,
-        #     # stop=["101"]
-        # )
-        # if return_str:
-        #     response = response["choices"][0]["text"]
-
-        from openai import AzureOpenAI
-        api_key = os.getenv("OPENAI_API_KEY")  # need to fill this in
-        client = AzureOpenAI(
-            azure_endpoint="https://healthcare-ai.openai.azure.com/",
-            api_version="2024-02-01",
-            api_key=api_key,
-        )
-
-        response = client.chat.completions.create(  # replace this value with the deployment name you chose when you deployed the associated model.
-            model=self.checkpoint,
-            messages=prompt,
-            temperature=temperature,
-            max_tokens=max_new_tokens,
-            top_p=1,
-            frequency_penalty=frequency_penalty,
-            presence_penalty=0,
-            stop=None)
-
-        if return_str:
-            response = response.choices[0].message.content
-
-        pkl.dump(response, open(cache_file, "wb"))
-        return response
-
-
 class LLM_Chat:
     """Chat models take a different format: https://platform.openai.com/docs/guides/chat/introduction"""
 
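The deleted LLM_OpenAI class memoized each response on disk under a sha256 hash of the prompt. That caching pattern is worth keeping in mind when reading the remaining classes; a standalone sketch of it (the helper name cached_llm_call is hypothetical):

    import hashlib
    import os
    import pickle as pkl
    from os.path import join

    def cached_llm_call(cache_dir, prompt, max_new_tokens, compute):
        """Memoize an LLM response on disk, keyed by sha256(prompt)."""
        os.makedirs(cache_dir, exist_ok=True)
        hash_str = hashlib.sha256(prompt.encode()).hexdigest()
        cache_file = join(cache_dir, f"{hash_str}__num_tok={max_new_tokens}.pkl")
        if os.path.exists(cache_file):
            with open(cache_file, "rb") as f:
                return pkl.load(f)  # cache hit: skip the API call entirely
        response = compute(prompt)  # cache miss: compute and store
        with open(cache_file, "wb") as f:
            pkl.dump(response, f)
        return response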
@@ -176,11 +112,10 @@ def __init__(self, checkpoint, seed, role, CACHE_DIR):
         self.checkpoint = checkpoint
         self.role = role
         from openai import AzureOpenAI
-        api_key = os.getenv("OPENAI_API_KEY")  # need to fill this in
         self.client = AzureOpenAI(
             azure_endpoint="https://healthcare-ai.openai.azure.com/",
             api_version="2024-02-01",
-            api_key=OPENAI_API_KEY,
+            api_key=OPENAI_API_KEY,
         )
 
     # @repeatedly_call_with_delay
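LLM_Chat now reads the module-level OPENAI_API_KEY rather than re-querying the environment per instance. As its docstring notes, chat models take a list of role-tagged messages rather than a bare prompt string; a hedged sketch of a call against such a client, assuming OPENAI_API_KEY was resolved by the module-level search in the first hunk (model name and prompt are illustrative):

    from openai import AzureOpenAI

    # OPENAI_API_KEY comes from the module-level search added by this commit
    client = AzureOpenAI(
        azure_endpoint="https://healthcare-ai.openai.azure.com/",
        api_version="2024-02-01",
        api_key=OPENAI_API_KEY,
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
    ]
    response = client.chat.completions.create(
        model="gpt-4",  # an Azure deployment name; illustrative
        messages=messages,
        temperature=0.1,
        max_tokens=250,
    )
    print(response.choices[0].message.content)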
