Skip to content

Commit

Permalink
fix: base_url not included when request to gpt in SingleTableGPTModel (
Browse files Browse the repository at this point in the history
…#205)

Co-authored-by: MoooCat <141886018+MooooCat@users.noreply.github.com>
  • Loading branch information
jalr4ever and MooooCat committed Jul 26, 2024
1 parent ad0942f commit 31eefae
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
10 changes: 7 additions & 3 deletions sdgx/models/LLM/single_table/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ def _get_openai_setting_from_env(self):
self.openai_API_url = os.getenv("OPENAI_URL")
logger.debug("Get OPENAI_URL from ENV.")

def openai_client(self):
"""
Generate a openai request client.
"""
return openai.OpenAI(api_key=self.openai_API_key, base_url=self.openai_API_url)

def ask_gpt(self, question, model=None):
"""
Sends a question to the GPT model.
Expand All @@ -156,13 +162,11 @@ def ask_gpt(self, question, model=None):
SynthesizerInitError: If the check method fails.
"""
self.check()
api_key = self.openai_API_key
if model:
model = model
else:
model = self.gpt_model
openai.api_key = api_key
client = openai.OpenAI(api_key=api_key)
client = self.openai_client()
logger.info(f"Ask GPT with temperature = {self.temperature}.")
response = client.chat.completions.create(
model=model,
Expand Down
12 changes: 12 additions & 0 deletions tests/models/test_singletable_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ def single_table_gpt_model():
gpt_response_sample_count = [20, 15, 20, 5, 5]


def test_singletable_gpt_model_openapi_setting(single_table_gpt_model: SingleTableGPTModel):
open_ai_key = "sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
open_ai_base = "https://api.mock.openai.base.com"
open_ai_model = "gpt-4o-mini"
single_table_gpt_model.set_openAI_settings(open_ai_base, open_ai_key)
single_table_gpt_model.gpt_model = open_ai_model
client = single_table_gpt_model.openai_client()
assert client.base_url == open_ai_base
assert client.api_key == open_ai_key
assert single_table_gpt_model.gpt_model == open_ai_model


def test_singletable_gpt_model(
single_table_gpt_model: SingleTableGPTModel,
raw_data: pd.DataFrame,
Expand Down

0 comments on commit 31eefae

Please sign in to comment.