From 219bb14337781f0b13a69a32443e011445e6f3c3 Mon Sep 17 00:00:00 2001
From: Yifan Mai
Date: Mon, 2 Sep 2024 20:06:28 -0700
Subject: [PATCH] Fix OpenAiInferenceEngine

Signed-off-by: Yifan Mai
---
 src/unitxt/inference.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index c6b6b26e0..6bda893e7 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -242,9 +242,9 @@ class OpenAiInferenceEngineParamsMixin(Artifact):
     top_p: Optional[float] = None
     top_logprobs: Optional[int] = 20
     logit_bias: Optional[Dict[str, int]] = None
-    logprobs: Optional[bool] = None
+    logprobs: Optional[bool] = True
     n: Optional[int] = None
-    parallel_tool_calls: bool = None
+    parallel_tool_calls: Optional[bool] = None
     service_tier: Optional[Literal["auto", "default"]] = None
 
 
@@ -259,9 +259,9 @@ class OpenAiInferenceEngineParams(Artifact):
     top_p: Optional[float] = None
     top_logprobs: Optional[int] = 20
     logit_bias: Optional[Dict[str, int]] = None
-    logprobs: Optional[bool] = None
+    logprobs: Optional[bool] = True
     n: Optional[int] = None
-    parallel_tool_calls: bool = None
+    parallel_tool_calls: Optional[bool] = None
     service_tier: Optional[Literal["auto", "default"]] = None
 
 
@@ -293,6 +293,13 @@ def prepare(self):
 
         self._set_inference_parameters()
 
+    def _get_completion_kwargs(self):
+        return {
+            k: v
+            for k, v in self.to_dict([OpenAiInferenceEngineParamsMixin]).items()
+            if v is not None
+        }
+
     def _infer(self, dataset):
         outputs = []
         for instance in tqdm(dataset, desc="Inferring with openAI API"):
@@ -308,7 +315,7 @@ def _infer(self, dataset):
                     }
                 ],
                 model=self.model_name,
-                **self.to_dict([OpenAiInferenceEngineParamsMixin]),
+                **self._get_completion_kwargs(),
             )
             output = response.choices[0].message.content
 
@@ -331,7 +338,7 @@ def _infer_log_probs(self, dataset):
                     }
                 ],
                 model=self.model_name,
-                **self.to_dict([OpenAiInferenceEngineParamsMixin]),
+                **self._get_completion_kwargs(),
             )
             top_logprobs_response = response.choices[0].logprobs.content
             output = [
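
The core of the fix is the new _get_completion_kwargs helper: instead of splatting every field of OpenAiInferenceEngineParamsMixin into client.chat.completions.create(...), only the parameters that were explicitly set are forwarded, so unset (None) values such as parallel_tool_calls never reach the API call. Below is a minimal standalone sketch of that filtering; the filter_completion_kwargs helper and the sample values are illustrative and not part of the patch.

# Standalone sketch (not part of the patch); filter_completion_kwargs and the
# sample parameter values are illustrative, not unitxt code.
from typing import Any, Dict, Optional


def filter_completion_kwargs(params: Dict[str, Optional[Any]]) -> Dict[str, Any]:
    # Keep only the parameters that were explicitly set, mirroring what
    # _get_completion_kwargs does with the mixin's fields.
    return {k: v for k, v in params.items() if v is not None}


params = {
    "temperature": None,
    "top_p": None,
    "top_logprobs": 20,
    "logprobs": True,
    "parallel_tool_calls": None,
}

print(filter_completion_kwargs(params))
# -> {'top_logprobs': 20, 'logprobs': True}

Forwarding only the explicitly set values presumably lets the API apply its own defaults for anything left unset, rather than receiving literal nulls for those parameters.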