
Commit

Merge branch 'main' into yifanmai/fix-openai-inference-engine
elronbandel committed Sep 8, 2024
2 parents e53722e + c43f57c commit 8ee0d76
Showing 1 changed file with 90 additions and 0 deletions.
src/unitxt/inference.py: 90 additions, 0 deletions
@@ -373,6 +373,96 @@ def _infer_log_probs(self, dataset):
        return outputs


class TogetherAiInferenceEngineParamsMixin(Artifact):
    max_tokens: Optional[int] = None
    stop: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    logprobs: Optional[int] = None
    echo: Optional[bool] = None
    n: Optional[int] = None
    min_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None


class TogetherAiInferenceEngine(
    InferenceEngine, TogetherAiInferenceEngineParamsMixin, PackageRequirementsMixin
):
    label: str = "together"
    model_name: str
    _requirements_list = {
        "together": "Install together package using 'pip install --upgrade together'"
    }
    data_classification_policy = ["public"]
    parameters: Optional[TogetherAiInferenceEngineParamsMixin] = None

    def prepare_engine(self):
        from together import Together
        from together.types.models import ModelType

        api_key_env_var_name = "TOGETHER_API_KEY"
        api_key = os.environ.get(api_key_env_var_name)
        assert api_key is not None, (
            f"Error while trying to run TogetherAiInferenceEngine."
            f" Please set the environment variable '{api_key_env_var_name}'."
        )
        self.client = Together(api_key=api_key)
        self._set_inference_parameters()

        # Get the model type from the Together List Models API
        together_models = self.client.models.list()
        together_model_id_to_type = {
            together_model.id: together_model.type for together_model in together_models
        }
        model_type = together_model_id_to_type.get(self.model_name)
        assert model_type is not None, (
            f"Could not find model {self.model_name} in the Together AI model list."
        )
        assert model_type in [ModelType.CHAT, ModelType.LANGUAGE, ModelType.CODE], (
            f"Together AI model type {model_type} is not supported; "
            "supported types are 'chat', 'language' and 'code'."
        )
        self.model_type = model_type

    def _get_infer_kwargs(self):
        # Pass along only the generation parameters that were explicitly set.
        return {
            k: v
            for k, v in self.to_dict([TogetherAiInferenceEngineParamsMixin]).items()
            if v is not None
        }

    def _infer_chat(self, prompt: str) -> str:
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            **self._get_infer_kwargs(),
        )
        return response.choices[0].message.content

    def _infer_text(self, prompt: str) -> str:
        response = self.client.completions.create(
            model=self.model_name,
            prompt=prompt,
            **self._get_infer_kwargs(),
        )
        return response.choices[0].text

    def _infer(self, dataset):
        from together.types.models import ModelType

        outputs = []
        if self.model_type == ModelType.CHAT:
            for instance in tqdm(dataset, desc="Inferring with Together AI Chat API"):
                outputs.append(self._infer_chat(instance["source"]))
        else:
            for instance in tqdm(dataset, desc="Inferring with Together AI Text API"):
                outputs.append(self._infer_text(instance["source"]))
        return outputs


class WMLInferenceEngineParamsMixin(Artifact):
    decoding_method: Optional[Literal["greedy", "sample"]] = None
    length_penalty: Optional[Dict[str, Union[int, float]]] = None
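For context, here is a minimal usage sketch of the class added in this diff (not part of the commit). It assumes the usual unitxt pattern in which the InferenceEngine base class exposes a public infer() method that delegates to _infer, that each instance carries its prompt under the "source" field, and that TOGETHER_API_KEY is set; the model id shown is illustrative.

    import os

    from unitxt.inference import TogetherAiInferenceEngine

    # Assumes TOGETHER_API_KEY is already exported, e.g. `export TOGETHER_API_KEY=...`.
    assert os.environ.get("TOGETHER_API_KEY") is not None

    engine = TogetherAiInferenceEngine(
        model_name="meta-llama/Llama-3-8b-chat-hf",  # illustrative model id
        max_tokens=256,
        temperature=0.7,
    )

    # Each instance carries its prompt under "source", which is what _infer reads.
    dataset = [{"source": "What is the capital of France?"}]
    predictions = engine.infer(dataset)
    print(predictions[0])

Note that constructing the engine triggers prepare_engine, which queries the Together List Models API, so a valid key and network access are required even for this sketch.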
