diff --git a/.github/workflows/catalog_consistency.yml b/.github/workflows/catalog_consistency.yml index 4ea4005e1..20de8d87f 100644 --- a/.github/workflows/catalog_consistency.yml +++ b/.github/workflows/catalog_consistency.yml @@ -12,7 +12,6 @@ jobs: runs-on: ubuntu-latest env: OS: ubuntu-latest - GENAI_KEY: ${{ secrets.GENAI_KEY }} UNITXT_DEFAULT_VERBOSITY: error DATASETS_VERBOSITY: error HF_HUB_VERBOSITY: error diff --git a/.github/workflows/catalog_preparation.yml b/.github/workflows/catalog_preparation.yml index a3024f409..468513f30 100644 --- a/.github/workflows/catalog_preparation.yml +++ b/.github/workflows/catalog_preparation.yml @@ -12,7 +12,6 @@ jobs: runs-on: ubuntu-latest env: OS: ubuntu-latest - GENAI_KEY: ${{ secrets.GENAI_KEY }} UNITXT_DEFAULT_VERBOSITY: error DATASETS_VERBOSITY: error HF_HUB_VERBOSITY: error diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index c6b6b26e0..c35bda5fe 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -11,6 +11,9 @@ from .image_operators import extract_images from .logging_utils import get_logger from .operator import PackageRequirementsMixin +from .settings_utils import get_settings + +settings = get_settings() class InferenceEngine(abc.ABC, Artifact): @@ -21,9 +24,20 @@ def _infer(self, dataset): """Perform inference on the input dataset.""" pass + @abc.abstractmethod + def prepare_engine(self): + """Prepare the inference engine.""" + pass + + def prepare(self): + if not settings.mock_inference_mode: + self.prepare_engine() + def infer(self, dataset) -> str: """Verifies instances of a dataset and performs inference.""" [self.verify_instance(instance) for instance in dataset] + if settings.mock_inference_mode: + return [instance["source"] for instance in dataset] return self._infer(dataset) @deprecation(version="2.0.0") @@ -122,7 +136,7 @@ def _prepare_pipeline(self): model=self.model_name, trust_remote_code=True, **model_args ) - def prepare(self): + def prepare_engine(self): if not self.lazy_load: self._prepare_pipeline() @@ -144,13 +158,17 @@ def _infer(self, dataset): class MockInferenceEngine(InferenceEngine): model_name: str - def prepare(self): + def prepare_engine(self): return def _infer(self, dataset): return ["[[10]]" for instance in dataset] +class MockModeMixin(Artifact): + mock_mode: bool = False + + class IbmGenAiInferenceEngineParamsMixin(Artifact): beam_width: Optional[int] = None decoding_method: Optional[Literal["greedy", "sample"]] = None @@ -201,11 +219,12 @@ class IbmGenAiInferenceEngine( data_classification_policy = ["public", "proprietary"] parameters: Optional[IbmGenAiInferenceEngineParams] = None - def prepare(self): + def prepare_engine(self): from genai import Client, Credentials api_key_env_var_name = "GENAI_KEY" api_key = os.environ.get(api_key_env_var_name) + assert api_key is not None, ( f"Error while trying to run IbmGenAiInferenceEngine." f" Please set the environment param '{api_key_env_var_name}'." ) @@ -279,7 +298,7 @@ class OpenAiInferenceEngine( data_classification_policy = ["public"] parameters: Optional[OpenAiInferenceEngineParams] = None - def prepare(self): + def prepare_engine(self): from openai import OpenAI api_key_env_var_name = "OPENAI_API_KEY" @@ -490,7 +509,7 @@ def _initialize_wml_client(self): client.set.default_project(self.credentials["project_id"]) return client - def prepare(self): + def prepare_engine(self): self._client = self._initialize_wml_client() self._set_inference_parameters() @@ -541,7 +560,7 @@ def _prepare_engine(self): self.processor = AutoProcessor.from_pretrained(self.model_name) - def prepare(self): + def prepare_engine(self): if not self.lazy_load: self._prepare_engine() diff --git a/src/unitxt/settings_utils.py b/src/unitxt/settings_utils.py index c6bbd8eac..9018a806c 100644 --- a/src/unitxt/settings_utils.py +++ b/src/unitxt/settings_utils.py @@ -146,6 +146,7 @@ def __getattr__(self, key): settings.seed = (int, 42) 
settings.skip_artifacts_prepare_and_verify = (bool, False) settings.data_classification_policy = None + settings.mock_inference_mode = (bool, False) if Constants.is_uninitilized(): constants = Constants() diff --git a/tests/utils.py b/tests/utils.py index 1c0c1a9c6..36b76fa73 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,7 +33,7 @@ def setUpClass(cls): enable_explicit_format() unitxt.settings.allow_unverified_code = True unitxt.settings.use_only_local_catalogs = True - # unitxt.settings.global_loader_limit = 300 + unitxt.settings.mock_inference_mode = True unitxt.settings.max_log_message_size = 1000000000000 if settings.default_verbosity in ["error", "critical"]: if not sys.warnoptions: