diff --git a/src/sparseml/export/export.py b/src/sparseml/export/export.py
index 360fd0326f6..c1e0f8b581e 100644
--- a/src/sparseml/export/export.py
+++ b/src/sparseml/export/export.py
@@ -98,9 +98,12 @@ def export(
         IntegrationHelperFunctions.load_from_registry(integration)
     )
 
-    model = helper_functions.create_model(source_path, device)
+    # for now, this code is not runnable, serves as a blueprint
+    model, auxiliary_items = helper_functions.create_model(
+        source_path, **kwargs  # noqa: F821
+    )
     sample_data = (
-        helper_functions.create_dummy_input(model, batch_size)
+        helper_functions.create_dummy_input(**auxiliary_items)
         if sample_data is None
         else sample_data
     )
diff --git a/src/sparseml/integration_helper_functions.py b/src/sparseml/integration_helper_functions.py
index a0201301ba8..68b04400e84 100644
--- a/src/sparseml/integration_helper_functions.py
+++ b/src/sparseml/integration_helper_functions.py
@@ -14,7 +14,7 @@
 
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from pydantic import BaseModel, Field
 
@@ -39,14 +39,26 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
     integration.
     """
 
-    create_model: Optional[Callable] = Field(
-        description="A function that creates a (sparse) "
-        "PyTorch model from a source path and additional "
-        "arguments"
+    create_model: Optional[
+        Callable[
+            [Union[str, Path], Optional[Dict[str, Any]]],
+            Tuple["torch.nn.Module", Optional[Dict[str, Any]]],  # noqa: F821
+        ]
+    ] = Field(
+        description="A function that takes: "
+        "- a source path to a PyTorch model "
+        "- (optionally) a dictionary of additional arguments "
+        "and returns: "
+        "- a (sparse) PyTorch model "
+        "- (optionally) a dictionary of additional arguments"
     )
-    create_dummy_input: Optional[Callable] = Field(
-        description="A function that creates a dummy input "
-        "given a (sparse) PyTorch model."
+    create_dummy_input: Optional[
+        Callable[..., "torch.Tensor"]  # noqa: F821
+    ] = Field(
+        description="A function that takes: "
+        "- a dictionary of arguments "
+        "and returns: "
+        "- a dummy input for the model (a torch.Tensor)"
     )
     export_model: Optional[Callable] = Field(
         description="A function that exports a (sparse) PyTorch "
diff --git a/src/sparseml/pytorch/image_classification/integration_helper_functions.py b/src/sparseml/pytorch/image_classification/integration_helper_functions.py
index 9f3ac2ddb95..5753b24c9c9 100644
--- a/src/sparseml/pytorch/image_classification/integration_helper_functions.py
+++ b/src/sparseml/pytorch/image_classification/integration_helper_functions.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from pathlib import Path
-from typing import Any, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 from pydantic import Field
@@ -27,17 +27,46 @@
 )
 
 
-def create_model(source_path: Union[Path, str], **kwargs) -> torch.nn.Module:
+def create_model(
+    source_path: Union[Path, str], **kwargs
+) -> Tuple[torch.nn.Module, Dict[str, Any]]:
     """
     A contract to create a model from a source path
 
     :param source_path: The path to the model
-    :return: The torch model
+    :param kwargs: Additional kwargs to pass to the model creation function
+    :return: A tuple of the
+        - torch model
+        - additional dictionary of useful objects created during model creation
     """
-    model, *_ = create_image_classification_model(checkpoint_path=source_path, **kwargs)
-    return model
+    model, *_, validation_loader = create_image_classification_model(
+        checkpoint_path=source_path, **kwargs
+    )
+    return model, dict(validation_loader=validation_loader)
+
+
+def create_dummy_input(
+    validation_loader: Optional[torch.utils.data.DataLoader] = None,
+    image_size: Optional[int] = 224,
+) -> torch.Tensor:
+    """
+    A contract to create a dummy input for a model
+
+    :param validation_loader: The validation loader to get a batch from.
+        If None, a fake batch will be created
+    :param image_size: The image size to use for the dummy input. Defaults to 224
+    :return: The dummy input as a torch tensor
+    """
+
+    if not validation_loader:
+        # create fake data for export
+        validation_loader = [[torch.randn(1, 3, image_size, image_size)]]
+    return next(iter(validation_loader))[0]
 
 
 @IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
 class ImageClassification(IntegrationHelperFunctions):
-    create_model: Any = Field(default=create_model)
+    create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
+        default=create_model
+    )
+    create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
diff --git a/src/sparseml/pytorch/image_classification/utils/helpers.py b/src/sparseml/pytorch/image_classification/utils/helpers.py
index 32e5df660bd..a10b496bc35 100644
--- a/src/sparseml/pytorch/image_classification/utils/helpers.py
+++ b/src/sparseml/pytorch/image_classification/utils/helpers.py
@@ -446,7 +446,7 @@ def create_model(
     if one_shot is not None:
         ScheduledModifierManager.from_yaml(file_path=one_shot).apply(module=model)
 
-    return model, arch_key, checkpoint_path
+    return model, arch_key, checkpoint_path, val_dataset
 
 
 def infer_num_classes(
diff --git a/tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py b/tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py
index c33df404b0e..851ef5c27e8 100644
--- a/tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py
+++ b/tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py
@@ -25,3 +25,4 @@ def test_integration_helper_functions():
         Integrations.image_classification.value
     )
     assert image_classification.create_model
+    assert image_classification.create_dummy_input
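
Below is a minimal usage sketch (not part of the diff) of how the reworked contract is expected to be wired together, mirroring the blueprint added to export(). The checkpoint path is a placeholder, and the sketch assumes the Integrations enum is importable from sparseml.integration_helper_functions and that the image-classification helper module has been imported so its registration has run:

    from sparseml.integration_helper_functions import (
        IntegrationHelperFunctions,
        Integrations,
    )

    # resolve the helper functions registered for the image-classification integration
    helper_functions = IntegrationHelperFunctions.load_from_registry(
        Integrations.image_classification.value
    )

    # create_model now returns the model plus a dict of auxiliary objects
    # (for image classification: the validation loader)
    model, auxiliary_items = helper_functions.create_model(
        "/path/to/checkpoint.pth"  # placeholder; integration-specific kwargs may follow
    )

    # the auxiliary objects feed create_dummy_input, which yields the sample
    # data used downstream by the export flow
    sample_data = helper_functions.create_dummy_input(**auxiliary_items)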