Skip to content

Commit

Permalink
Add mock inference mode setting and allow testing without gen ai key (#1204)
Browse files Browse the repository at this point in the history

* Add mock inference mode setting

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Update

Signed-off-by: elronbandel <elronbandel@gmail.com>

* remove gen ai key

Signed-off-by: elronbandel <elronbandel@gmail.com>

---------

Signed-off-by: elronbandel <elronbandel@gmail.com>
  • Loading branch information
elronbandel committed Sep 8, 2024
1 parent 7e3caf5 commit ba2f04a
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 9 deletions.
1 change: 0 additions & 1 deletion .github/workflows/catalog_consistency.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ jobs:
runs-on: ubuntu-latest
env:
OS: ubuntu-latest
GENAI_KEY: ${{ secrets.GENAI_KEY }}
UNITXT_DEFAULT_VERBOSITY: error
DATASETS_VERBOSITY: error
HF_HUB_VERBOSITY: error
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/catalog_preparation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ jobs:
runs-on: ubuntu-latest
env:
OS: ubuntu-latest
GENAI_KEY: ${{ secrets.GENAI_KEY }}
UNITXT_DEFAULT_VERBOSITY: error
DATASETS_VERBOSITY: error
HF_HUB_VERBOSITY: error
Expand Down
31 changes: 25 additions & 6 deletions src/unitxt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from .image_operators import extract_images
from .logging_utils import get_logger
from .operator import PackageRequirementsMixin
from .settings_utils import get_settings

settings = get_settings()


class InferenceEngine(abc.ABC, Artifact):
Expand All @@ -21,9 +24,20 @@ def _infer(self, dataset):
"""Perform inference on the input dataset."""
pass

@abc.abstractmethod
def prepare_engine(self):
    """Perform engine-specific setup (models, clients, credentials).

    Implemented by subclasses; called from ``prepare()`` unless mock
    inference mode is enabled.
    """
    pass

def prepare(self):
    """Run engine setup, skipping it entirely in mock inference mode."""
    if settings.mock_inference_mode:
        return
    self.prepare_engine()

def infer(self, dataset) -> list:
    """Verify every instance of the dataset, then perform inference.

    In mock inference mode the model is bypassed and each instance's
    "source" field is echoed back instead of a prediction.

    Note: annotation corrected — both paths return a list, not ``str``.
    """
    # Plain loop: verification is a side effect, not a value to collect.
    for instance in dataset:
        self.verify_instance(instance)
    if settings.mock_inference_mode:
        return [instance["source"] for instance in dataset]
    return self._infer(dataset)

@deprecation(version="2.0.0")
Expand Down Expand Up @@ -122,7 +136,7 @@ def _prepare_pipeline(self):
model=self.model_name, trust_remote_code=True, **model_args
)

def prepare_engine(self):
    """Eagerly build the HF pipeline unless lazy loading is enabled."""
    # Stale pre-rename diff line (`def prepare(self):`) removed; this is
    # the post-commit method.
    if not self.lazy_load:
        self._prepare_pipeline()

Expand All @@ -144,13 +158,17 @@ def _infer(self, dataset):
class MockInferenceEngine(InferenceEngine):
    """Inference engine that returns a fixed canned response for testing."""

    # Identifier only; no real model is ever loaded.
    model_name: str

    def prepare_engine(self):
        """No-op: a mock engine needs no setup."""
        return

    def _infer(self, dataset):
        # One fixed "[[10]]" response per instance (a judge-style rating).
        return ["[[10]]" for _ in dataset]


class MockModeMixin(Artifact):
    """Mixin adding a ``mock_mode`` flag (default: ``False``) to an Artifact."""

    mock_mode: bool = False


class IbmGenAiInferenceEngineParamsMixin(Artifact):
beam_width: Optional[int] = None
decoding_method: Optional[Literal["greedy", "sample"]] = None
Expand Down Expand Up @@ -201,11 +219,12 @@ class IbmGenAiInferenceEngine(
data_classification_policy = ["public", "proprietary"]
parameters: Optional[IbmGenAiInferenceEngineParams] = None

def prepare(self):
def prepare_engine(self):
from genai import Client, Credentials

api_key_env_var_name = "GENAI_KEY"
api_key = os.environ.get(api_key_env_var_name)

assert api_key is not None, (
f"Error while trying to run IbmGenAiInferenceEngine."
f" Please set the environment param '{api_key_env_var_name}'."
Expand Down Expand Up @@ -279,7 +298,7 @@ class OpenAiInferenceEngine(
data_classification_policy = ["public"]
parameters: Optional[OpenAiInferenceEngineParams] = None

def prepare(self):
def prepare_engine(self):
from openai import OpenAI

api_key_env_var_name = "OPENAI_API_KEY"
Expand Down Expand Up @@ -490,7 +509,7 @@ def _initialize_wml_client(self):
client.set.default_project(self.credentials["project_id"])
return client

def prepare_engine(self):
    """Create the WML client and apply the configured inference parameters."""
    # Stale pre-rename diff line (`def prepare(self):`) removed.
    self._client = self._initialize_wml_client()

    self._set_inference_parameters()
Expand Down Expand Up @@ -541,7 +560,7 @@ def _prepare_engine(self):

self.processor = AutoProcessor.from_pretrained(self.model_name)

def prepare_engine(self):
    """Eagerly initialize the underlying engine unless lazy loading is enabled."""
    # Stale pre-rename diff line (`def prepare(self):`) removed.
    if not self.lazy_load:
        self._prepare_engine()

Expand Down
1 change: 1 addition & 0 deletions src/unitxt/settings_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __getattr__(self, key):
# Each entry is a (type, default) pair consumed by the Settings object.
settings.seed = (int, 42)
settings.skip_artifacts_prepare_and_verify = (bool, False)
settings.data_classification_policy = None
# When True, InferenceEngine.prepare() skips engine setup and infer()
# echoes each instance's "source" field instead of calling a model.
settings.mock_inference_mode = (bool, False)

if Constants.is_uninitilized():
constants = Constants()
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def setUpClass(cls):
enable_explicit_format()
unitxt.settings.allow_unverified_code = True
unitxt.settings.use_only_local_catalogs = True
# unitxt.settings.global_loader_limit = 300
unitxt.settings.mock_inference_mode = True
unitxt.settings.max_log_message_size = 1000000000000
if settings.default_verbosity in ["error", "critical"]:
if not sys.warnoptions:
Expand Down

0 comments on commit ba2f04a

Please sign in to comment.