From 130cbbff867592a2301029a0b74f2707ed5f790e Mon Sep 17 00:00:00 2001
From: Jayesh Sharma
Date: Wed, 4 Sep 2024 17:27:03 +0530
Subject: [PATCH] add support for multi model comparison

---
 .../deepchecks_data_validator.py              | 68 +++++++++++++++----
 1 file changed, 56 insertions(+), 12 deletions(-)

diff --git a/src/zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py b/src/zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
index 7fde074d4fb..09d8cae43ce 100644
--- a/src/zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
+++ b/src/zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
@@ -17,6 +17,7 @@
     Any,
     ClassVar,
     Dict,
+    List,
     Optional,
     Sequence,
     Tuple,
@@ -30,6 +31,8 @@
 from deepchecks.tabular import Dataset as TabularData
 from deepchecks.tabular import Suite as TabularSuite
 
+from deepchecks.tabular import ModelComparisonSuite
+
 # not part of deepchecks.tabular.checks
 from deepchecks.tabular.suites import full_suite as full_tabular_suite
 from deepchecks.vision import Suite as VisionSuite
@@ -102,7 +105,7 @@ def _create_and_run_check_suite(
         comparison_dataset: Optional[
             Union[pd.DataFrame, DataLoader[Any]]
         ] = None,
-        model: Optional[Union[ClassifierMixin, Module]] = None,
+        models: Optional[List[Union[ClassifierMixin, Module]]] = None,
         check_list: Optional[Sequence[str]] = None,
         dataset_kwargs: Dict[str, Any] = {},
         check_kwargs: Dict[str, Dict[str, Any]] = {},
@@ -123,7 +126,7 @@
                 validation.
             comparison_dataset: Optional secondary (comparison) dataset
                 argument used during comparison checks.
-            model: Optional model argument used during validation.
+            models: Optional list of models used during validation.
             check_list: Optional list of ZenML Deepchecks check identifiers
                 specifying the list of Deepchecks checks to be performed.
             dataset_kwargs: Additional keyword arguments to be passed to the
@@ -149,6 +152,7 @@
         # arguments and the check list.
         is_tabular = False
         is_vision = False
+        is_multi_model = False
         for dataset in [reference_dataset, comparison_dataset]:
             if dataset is None:
                 continue
@@ -163,7 +167,18 @@
                     f"data and {str(DataLoader)} for computer vision data."
                 )
 
-        if model:
+        if models:
+            # if there is more than one model, flag the run as a
+            # multi-model comparison
+            if len(models) > 1:
+                is_multi_model = True
+                # only models of the same type can be compared, so raise
+                # an error if the supplied models have different types
+                if len(set(type(model) for model in models)) > 1:
+                    raise TypeError(
+                        "Models used for comparison checks must be of the same type."
+                    )
+            model = models[0]
             if isinstance(model, ClassifierMixin):
                 is_tabular = True
             elif isinstance(model, Module):
@@ -190,8 +205,18 @@
         if not check_list:
             # default to executing all the checks listed in the supplied
             # checks enum type if a custom check list is not supplied
+            # don't include the TABULAR_PERFORMANCE_BIAS check enum value
+            # because it requires a protected feature name to be set
+            checks_to_exclude = [
+                DeepchecksModelValidationCheck.TABULAR_PERFORMANCE_BIAS
+            ]
+            check_enum_values = [
+                check.value
+                for check in check_enum
+                if check not in checks_to_exclude
+            ]
             tabular_checks, vision_checks = cls._split_checks(
-                check_enum.values()
+                check_enum_values
             )
             if is_tabular:
                 check_list = tabular_checks
@@ -254,6 +279,10 @@
             suite_class = VisionSuite
             full_suite = full_vision_suite()
 
+        # when comparing multiple models, use the ModelComparisonSuite instead
+        if is_multi_model:
+            suite_class = ModelComparisonSuite
+
         train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
         test_dataset = None
         if comparison_dataset is not None:
@@ -294,13 +323,28 @@
                     continue
                 condition_method(**condition_kwargs)
 
-            suite.add(check)
-        return suite.run(
-            train_dataset=train_dataset,
-            test_dataset=test_dataset,
-            model=model,
-            **run_kwargs,
-        )
+            # only add the check if it is supported by the selected suite
+            if isinstance(check, suite.supported_checks()):
+                suite.add(check)
+            else:
+                logger.warning(
+                    f"Check {check_name} is not supported by the "
+                    f"{suite_class.__name__} suite. Ignoring the check."
+                )
+
+        if isinstance(suite, ModelComparisonSuite):
+            return suite.run(
+                models=models,
+                train_datasets=train_dataset,
+                test_datasets=test_dataset,
+            )
+        else:
+            return suite.run(
+                train_dataset=train_dataset,
+                test_dataset=test_dataset,
+                model=models[0] if models else None,
+                **run_kwargs,
+            )
 
     def data_validation(
         self,
@@ -444,7 +488,7 @@ def model_validation(
             check_enum=check_enum,
             reference_dataset=dataset,
             comparison_dataset=comparison_dataset,
-            model=model,
+            models=[model],
            check_list=check_list,
             dataset_kwargs=dataset_kwargs,
             check_kwargs=check_kwargs,
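
Usage sketch (not part of the patch): a minimal, self-contained example of the
Deepchecks ModelComparisonSuite call that the new is_multi_model branch delegates
to. Only the models/train_datasets/test_datasets keyword arguments mirror the
suite.run call added above; the iris data, the RandomForestClassifier models and
the MultiModelPerformanceReport check are illustrative assumptions.

from deepchecks.tabular import Dataset, ModelComparisonSuite
from deepchecks.tabular.checks import MultiModelPerformanceReport
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Build a reference (train) and a comparison (test) tabular dataset.
frame = load_iris(as_frame=True).frame
train_df, test_df = train_test_split(frame, random_state=0)
train_ds = Dataset(train_df, label="target")
test_ds = Dataset(test_df, label="target")

# Two fitted models of the same type, as the patch's type check requires.
features = train_df.drop(columns="target")
labels = train_df["target"]
model_a = RandomForestClassifier(n_estimators=10, random_state=0).fit(features, labels)
model_b = RandomForestClassifier(n_estimators=100, random_state=1).fit(features, labels)

# Run the comparison suite with the same keyword arguments the patch passes.
suite = ModelComparisonSuite("multi model comparison", MultiModelPerformanceReport())
result = suite.run(
    train_datasets=train_ds,
    test_datasets=test_ds,
    models=[model_a, model_b],
)
# The returned result can then be rendered or saved like any other suite result.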