Skip to content

Commit

Permalink
add support for multi-model comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
wjayesh committed Sep 4, 2024
1 parent 281e7f7 commit 130cbbf
Showing 1 changed file with 56 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Any,
ClassVar,
Dict,
List,
Optional,
Sequence,
Tuple,
Expand All @@ -30,6 +31,8 @@
from deepchecks.tabular import Dataset as TabularData
from deepchecks.tabular import Suite as TabularSuite

from deepchecks.tabular import ModelComparisonSuite

# not part of deepchecks.tabular.checks
from deepchecks.tabular.suites import full_suite as full_tabular_suite
from deepchecks.vision import Suite as VisionSuite
Expand Down Expand Up @@ -102,7 +105,7 @@ def _create_and_run_check_suite(
comparison_dataset: Optional[
Union[pd.DataFrame, DataLoader[Any]]
] = None,
model: Optional[Union[ClassifierMixin, Module]] = None,
models: Optional[List[Union[ClassifierMixin, Module]]] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
Expand All @@ -123,7 +126,7 @@ def _create_and_run_check_suite(
validation.
comparison_dataset: Optional secondary (comparison) dataset argument
used during comparison checks.
model: Optional model argument used during validation.
models: Optional model argument used during validation.
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the list of Deepchecks checks to be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Expand All @@ -149,6 +152,7 @@ def _create_and_run_check_suite(
# arguments and the check list.
is_tabular = False
is_vision = False
is_multi_model = False
for dataset in [reference_dataset, comparison_dataset]:
if dataset is None:
continue
Expand All @@ -163,7 +167,18 @@ def _create_and_run_check_suite(
f"data and {str(DataLoader)} for computer vision data."
)

if model:
if models:
            # if there's more than one model, set is_multi_model
            # to True
if len(models) > 1:
is_multi_model = True
            # if the models are of different types, raise an error;
            # only models of the same type can be used for comparison
if len(set(type(model) for model in models)) > 1:
raise TypeError(
f"Models used for comparison checks must be of the same type."
)
model = models[0]
if isinstance(model, ClassifierMixin):
is_tabular = True
elif isinstance(model, Module):
Expand All @@ -190,8 +205,18 @@ def _create_and_run_check_suite(
if not check_list:
# default to executing all the checks listed in the supplied
# checks enum type if a custom check list is not supplied
# don't include the TABULAR_PERFORMANCE_BIAS check enum value
# as it requires a protected feature name to be set
checks_to_exclude = [
DeepchecksModelValidationCheck.TABULAR_PERFORMANCE_BIAS
]
check_enum_values = [
check.value
for check in check_enum
if check not in checks_to_exclude
]
tabular_checks, vision_checks = cls._split_checks(
check_enum.values()
check_enum_values
)
if is_tabular:
check_list = tabular_checks
Expand Down Expand Up @@ -254,6 +279,10 @@ def _create_and_run_check_suite(
suite_class = VisionSuite
full_suite = full_vision_suite()

# if is_multi_model is True, we need to use the ModelComparisonSuite
if is_multi_model:
suite_class = ModelComparisonSuite

train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
test_dataset = None
if comparison_dataset is not None:
Expand Down Expand Up @@ -294,13 +323,28 @@ def _create_and_run_check_suite(
continue
condition_method(**condition_kwargs)

suite.add(check)
return suite.run(
train_dataset=train_dataset,
test_dataset=test_dataset,
model=model,
**run_kwargs,
)
# if the check is supported by the suite, add it
if isinstance(check, suite.supported_checks()):
suite.add(check)
else:
logger.warning(
f"Check {check_name} is not supported by the {suite_class} "
"suite. Ignoring the check."
)

if isinstance(suite, ModelComparisonSuite):
return suite.run(
models=models,
train_datasets=train_dataset,
test_datasets=test_dataset,
)
else:
return suite.run(
train_dataset=train_dataset,
test_dataset=test_dataset,
model=models[0] if models else None,
**run_kwargs,
)

def data_validation(
self,
Expand Down Expand Up @@ -444,7 +488,7 @@ def model_validation(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
model=model,
models=[model],
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
Expand Down

0 comments on commit 130cbbf

Please sign in to comment.