Merge branch 'main' into yifanmai/fix-openai-inference-engine
elronbandel committed Sep 8, 2024
2 parents 6491183 + ba2f04a commit 0827a1c
Showing 7 changed files with 42 additions and 14 deletions.
1 change: 0 additions & 1 deletion .github/workflows/catalog_consistency.yml
@@ -12,7 +12,6 @@ jobs:
    runs-on: ubuntu-latest
    env:
      OS: ubuntu-latest
-      GENAI_KEY: ${{ secrets.GENAI_KEY }}
      UNITXT_DEFAULT_VERBOSITY: error
      DATASETS_VERBOSITY: error
      HF_HUB_VERBOSITY: error
1 change: 0 additions & 1 deletion .github/workflows/catalog_preparation.yml
@@ -12,7 +12,6 @@ jobs:
    runs-on: ubuntu-latest
    env:
      OS: ubuntu-latest
-      GENAI_KEY: ${{ secrets.GENAI_KEY }}
      UNITXT_DEFAULT_VERBOSITY: error
      DATASETS_VERBOSITY: error
      HF_HUB_VERBOSITY: error
10 changes: 5 additions & 5 deletions README.md
@@ -31,11 +31,11 @@ https://github.com/IBM/unitxt/assets/23455264/baef9131-39d4-4164-90b2-05da52919f

### 🦄 Currently on Unitxt Catalog

-![NLP Tasks](https://img.shields.io/badge/NLP_tasks-40-blue)
-![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-457-blue)
-![Templates](https://img.shields.io/badge/Templates-229-blue)
-![Formats](https://img.shields.io/badge/Formats-18-blue)
-![Metrics](https://img.shields.io/badge/Metrics-98-blue)
+![NLP Tasks](https://img.shields.io/badge/NLP_tasks-48-blue)
+![Dataset Cards](https://img.shields.io/badge/Dataset_Cards-537-blue)
+![Templates](https://img.shields.io/badge/Templates-265-blue)
+![Formats](https://img.shields.io/badge/Formats-23-blue)
+![Metrics](https://img.shields.io/badge/Metrics-136-blue)

### 🦄 Run Unitxt Exploration Dashboard

10 changes: 10 additions & 0 deletions src/unitxt/deprecation_utils.py
@@ -1,6 +1,7 @@
import functools
import warnings

+from .error_utils import UnitxtWarning
from .settings_utils import get_constants, get_settings

constants = get_constants()
@@ -98,3 +99,12 @@ def decorator(obj):
        return depraction_wrapper(func, version, alt_text)

    return decorator
+
+
+def init_warning(msg=""):
+    # Decorator that raises warning when class is initialized
+    def decorator(initiated_class):
+        UnitxtWarning(msg)
+        return initiated_class
+
+    return decorator
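Note that, as committed, init_warning emits the warning when the decorator is applied (at class definition time) rather than when the class is instantiated, the inline comment notwithstanding. A minimal usage sketch, assuming an installed unitxt; the decorated class below is hypothetical:

```python
# Sketch of the new init_warning decorator; ExperimentalEngine is a
# hypothetical example class, not part of this commit.
from unitxt.deprecation_utils import init_warning


@init_warning(msg="ExperimentalEngine is experimental and may change.")
class ExperimentalEngine:
    pass


# The UnitxtWarning fires once, when the decorator runs at class
# definition time; instantiating afterwards does not re-emit it.
engine = ExperimentalEngine()
```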
31 changes: 25 additions & 6 deletions src/unitxt/inference.py
@@ -11,6 +11,9 @@
from .image_operators import extract_images
from .logging_utils import get_logger
from .operator import PackageRequirementsMixin
+from .settings_utils import get_settings
+
+settings = get_settings()


class InferenceEngine(abc.ABC, Artifact):
@@ -21,9 +24,20 @@ def _infer(self, dataset):
        """Perform inference on the input dataset."""
        pass

+    @abc.abstractmethod
+    def prepare_engine(self):
+        """Perform the engine-specific preparation steps."""
+        pass
+
+    def prepare(self):
+        if not settings.mock_inference_mode:
+            self.prepare_engine()
+
    def infer(self, dataset) -> str:
        """Verifies instances of a dataset and performs inference."""
        [self.verify_instance(instance) for instance in dataset]
+        if settings.mock_inference_mode:
+            return [instance["source"] for instance in dataset]
        return self._infer(dataset)

@deprecation(version="2.0.0")
@@ -122,7 +136,7 @@ def _prepare_pipeline(self):
            model=self.model_name, trust_remote_code=True, **model_args
        )

-    def prepare(self):
+    def prepare_engine(self):
        if not self.lazy_load:
            self._prepare_pipeline()

@@ -144,13 +158,17 @@ def _infer(self, dataset):
class MockInferenceEngine(InferenceEngine):
    model_name: str

-    def prepare(self):
+    def prepare_engine(self):
        return

    def _infer(self, dataset):
        return ["[[10]]" for instance in dataset]


+class MockModeMixin(Artifact):
+    mock_mode: bool = False
+
+
class IbmGenAiInferenceEngineParamsMixin(Artifact):
    beam_width: Optional[int] = None
    decoding_method: Optional[Literal["greedy", "sample"]] = None
@@ -201,11 +219,12 @@ class IbmGenAiInferenceEngine(
    data_classification_policy = ["public", "proprietary"]
    parameters: Optional[IbmGenAiInferenceEngineParams] = None

-    def prepare(self):
+    def prepare_engine(self):
        from genai import Client, Credentials

        api_key_env_var_name = "GENAI_KEY"
        api_key = os.environ.get(api_key_env_var_name)
+
        assert api_key is not None, (
            f"Error while trying to run IbmGenAiInferenceEngine."
            f" Please set the environment param '{api_key_env_var_name}'."
@@ -279,7 +298,7 @@ class OpenAiInferenceEngine(
    data_classification_policy = ["public"]
    parameters: Optional[OpenAiInferenceEngineParams] = None

-    def prepare(self):
+    def prepare_engine(self):
        from openai import OpenAI

        api_key_env_var_name = "OPENAI_API_KEY"
@@ -497,7 +516,7 @@ def _initialize_wml_client(self):
        client.set.default_project(self.credentials["project_id"])
        return client

-    def prepare(self):
+    def prepare_engine(self):
        self._client = self._initialize_wml_client()

        self._set_inference_parameters()
@@ -548,7 +567,7 @@ def _prepare_engine(self):

        self.processor = AutoProcessor.from_pretrained(self.model_name)

-    def prepare(self):
+    def prepare_engine(self):
        if not self.lazy_load:
            self._prepare_engine()

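Taken together, the inference.py changes split engine setup out of prepare() into an abstract prepare_engine() that each engine implements, and gate both setup and inference on the new settings.mock_inference_mode flag: when it is on, prepare() skips prepare_engine() and infer() echoes each instance's "source" field instead of calling _infer(). A minimal sketch of the resulting behavior, using only names visible in this diff (the dataset literal is illustrative):

```python
# Sketch of the new mock-mode control flow; the dataset is illustrative.
import unitxt
from unitxt.inference import MockInferenceEngine

# Flip the flag before constructing the engine, so prepare() also skips
# prepare_engine().
unitxt.settings.mock_inference_mode = True

engine = MockInferenceEngine(model_name="any-model")
dataset = [{"source": "Translate to French: Hello"}]

# infer() verifies each instance, then short-circuits and returns the
# "source" values untouched instead of calling _infer() (which would
# otherwise yield "[[10]]" per instance for this engine).
print(engine.infer(dataset))  # -> ['Translate to French: Hello']
```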
1 change: 1 addition & 0 deletions src/unitxt/settings_utils.py
@@ -146,6 +146,7 @@ def __getattr__(self, key):
settings.seed = (int, 42)
settings.skip_artifacts_prepare_and_verify = (bool, False)
settings.data_classification_policy = None
+settings.mock_inference_mode = (bool, False)

if Constants.is_uninitilized():
    constants = Constants()
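The new settings entry registers mock_inference_mode as a typed setting: the (bool, False) tuple declares a boolean with a default of False, matching the style of the neighboring entries. A minimal sketch of reading and toggling it, mirroring what the tests/utils.py change below does for the whole test suite:

```python
# Sketch of using the newly registered setting.
import unitxt

# Typed settings read as plain attributes; False is the registered default.
print(unitxt.settings.mock_inference_mode)  # -> False

# Toggle in code, as tests/utils.py now does in setUpClass:
unitxt.settings.mock_inference_mode = True

# By unitxt's convention for typed settings, an environment variable can
# also drive it (assumed name, not shown in this diff):
#   UNITXT_MOCK_INFERENCE_MODE=True
```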
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -33,7 +33,7 @@ def setUpClass(cls):
        enable_explicit_format()
        unitxt.settings.allow_unverified_code = True
        unitxt.settings.use_only_local_catalogs = True
-        # unitxt.settings.global_loader_limit = 300
+        unitxt.settings.mock_inference_mode = True
        unitxt.settings.max_log_message_size = 1000000000000
        if settings.default_verbosity in ["error", "critical"]:
            if not sys.warnoptions:
