Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Export Refactor][Image Classification] create_model function #1878

13 changes: 13 additions & 0 deletions src/sparseml/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
55 changes: 2 additions & 53 deletions src/sparseml/export.py → src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,16 @@

import logging
from pathlib import Path
from typing import Any, Callable, List, Optional, Union

from pydantic import BaseModel, Field
from typing import Any, List, Optional, Union

from sparseml.export.registry import IntegrationHelperFunctions, infer_integration
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparsezoo.utils.registry import RegistryMixin


_LOGGER = logging.getLogger(__name__)
AVAILABLE_DEPLOYMENT_TARGETS = ["deepsparse", "onnxruntime"]


class IntegrationHelperFunctions(BaseModel, RegistryMixin):
"""
Registry that maps integration names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path."
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
"with the appropriate structure."
)


def export(
source_path: Union[Path, str],
target_path: Union[Path, str],
Expand Down Expand Up @@ -166,20 +129,6 @@ def export(
)


def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.
For example:
- for transformers model the integration
can be inferred from `config.json`
- for computer vision, the integration
can be inferred from the model architecture (`arch_key`)
:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
raise NotImplementedError


def validate_correctness(deployment_path: Union[Path, str]):
"""
Validate the correctness of the exported model.
Expand Down
113 changes: 113 additions & 0 deletions src/sparseml/export/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from pathlib import Path
from typing import Any, Callable, Optional, Union

from pydantic import BaseModel, Field, validator

from sparseml.pytorch.image_classification.utils.helpers import (
create_model as create_model_ic,
)
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)
from sparsezoo.utils.registry import RegistryMixin


__all__ = ["IntegrationHelperFunctions", "infer_integration"]


class Integrations(Enum):
"""
Holds the names of the available integrations.
"""

image_classification = "image-classification"


# TODO: Fold it into the functionalities of the RegistryMixin
# when the `resolve` method is generically implemented
def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.

:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
if is_image_classification_model(source_path):
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
return Integrations.image_classification.value
else:
raise ValueError(
f"Could not infer integration from source_path: {source_path}."
f"Please specify an argument `integration` from one of"
f"the available integrations: {list(Integrations)}."
)


class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"""
Registry that maps names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
description="A function that creates a (sparse) "
"PyTorch model from a source path and additional "
"arguments"
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
"with the appropriate structure."
)

# use validator to ensure that "create_model" outputs only the first output
@validator("create_model", pre=True)
def create_model_only_one_output(cls, v: Optional[Callable]) -> Optional[Callable]:
"""
Ensure that the create_model function only outputs
the first output - the model itself.
"""
if v is not None:
v = cls.wrap_to_return_first_output(v)
return v

@staticmethod
def wrap_to_return_first_output(func: Callable) -> Callable:
return lambda *args, **kwargs: func(*args, **kwargs)[0]


@IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
class ImageClassification(IntegrationHelperFunctions):
create_model: Any = Field(default=create_model_ic)
57 changes: 56 additions & 1 deletion src/sparseml/pytorch/image_classification/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import warnings
from contextlib import nullcontext
from enum import Enum, auto, unique
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -344,17 +345,24 @@ def get_dataset_and_dataloader(
# Model creation Helpers
def create_model(
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
checkpoint_path: str,
num_classes: int,
dataset_name: Optional[str] = None,
dataset_path: Optional[str] = None,
num_classes: Optional[int] = None,
recipe_path: Optional[str] = None,
arch_key: Optional[str] = None,
pretrained: Union[bool, str] = False,
pretrained_dataset: Optional[str] = None,
image_size: int = 224,
local_rank: int = -1,
**model_kwargs,
) -> Tuple[Module, str, str]:
"""
:param checkpoint_path: Path to the checkpoint to load. `zoo` for
downloading weights with respect to a SparseZoo recipe
:param dataset_name: The name of the dataset to use for model creation.
Defaults to `None`
:param dataset_path: The path to the dataset to use for model creation.
Defaults to `None`
:param num_classes: Integer representing the number of output classes
:param recipe_path: Path or SparseZoo stub to the recipe for downloading,
respective model. Defaults to `None`
Expand All @@ -364,11 +372,44 @@ def create_model(
False
:param pretrained_dataset: The dataset to used for pretraining. Defaults to
None
:param image_size: The image size to use for inference of num_classes
(in case num_classes is None) . Defaults to 224
:param local_rank: The local rank of the process. Defaults to -1
:param model_kwargs: Additional keyword arguments to pass to the model
:returns: A tuple containing the mode, the model's arch_key, and the
checkpoint path
"""
if num_classes is None:
# infer number of classes from the dataset
if dataset_name is None and dataset_path is None:
raise ValueError(
"To create a model either specify num_classes or "
"dataset_name and dataset_path (so that num_classes can be inferred)"
)
val_dataset, _ = get_dataset_and_dataloader(
dataset_name=dataset_name,
dataset_path=dataset_path,
batch_size=1,
image_size=image_size,
training=False,
loader_num_workers=1,
loader_pin_memory=False,
max_samples=1,
)

num_classes = infer_num_classes(
train_dataset=None,
val_dataset=val_dataset,
dataset=dataset_name,
model_kwargs=model_kwargs,
)
else:
if dataset_name is not None or dataset_path is not None:
warnings.warn(
"Both num_classes and dataset_name/dataset_path were provided. "
"Using num_classes and ignoring dataset_name/dataset_path"
)

with torch_distributed_zero_first(local_rank):
# only download once locally
if checkpoint_path and checkpoint_path.startswith("zoo"):
Expand Down Expand Up @@ -643,3 +684,17 @@ def _download_model_from_zoo_using_recipe(
Model(recipe_stub),
recipe_name=recipe_type,
)


def is_image_classification_model(source_path: Union[Path, str]) -> bool:
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
"""
:param source_path: The path to the model
:return: Whether the model is an image classification model or not
"""
try:
checkpoint = torch.load(os.path.join(source_path, "model.pth"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this expects directory, we should also allow a file for IC

arch_key = checkpoint.get("arch_key")
if arch_key:
return True
except Exception:
return False
13 changes: 13 additions & 0 deletions tests/sparseml/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
22 changes: 22 additions & 0 deletions tests/sparseml/export/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.sparseml.export.registry import IntegrationHelperFunctions, Integrations


def test_integration_helper_functions():
image_classification = IntegrationHelperFunctions.load_from_registry(
Integrations.image_classification.value
)
assert image_classification.create_model
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from pathlib import Path

import pytest

from sparseml.pytorch.image_classification.utils.helpers import save_zoo_directory
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
save_zoo_directory,
)
from sparsezoo import Model


Expand Down Expand Up @@ -48,3 +52,13 @@ def test_save_zoo_directory(stub, tmp_path_factory):
assert new_zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
shutil.rmtree(path_to_training_outputs)
shutil.rmtree(save_dir)


@pytest.mark.parametrize(
"stub, is_image_classification",
[("zoo:efficientnet_v2-s-imagenet-base_quantized", True)],
)
def test_is_image_classification_model(stub, is_image_classification):
path_to_model = Model(stub).training.path
assert is_image_classification_model(path_to_model)
assert is_image_classification_model(Path(path_to_model))