Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Export Refactor][Image Classification] create_model function #1878

13 changes: 13 additions & 0 deletions src/sparseml/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
61 changes: 19 additions & 42 deletions src/sparseml/export.py → src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,52 +13,26 @@
# limitations under the License.

import logging
from enum import Enum
from pathlib import Path
from typing import Any, Callable, List, Optional, Union
from typing import Any, List, Optional, Union

from pydantic import BaseModel, Field
from sparseml.export.registry import IntegrationHelperFunctions

# TODO: Fold it into the functionalities of the RegistryMixin
# when the resolve method is implemented
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparsezoo.utils.registry import RegistryMixin


_LOGGER = logging.getLogger(__name__)
AVAILABLE_DEPLOYMENT_TARGETS = ["deepsparse", "onnxruntime"]


class IntegrationHelperFunctions(BaseModel, RegistryMixin):
"""
Registry that maps integration names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path."
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
"with the appropriate structure."
)
class Integrations(Enum):
image_classification = "image_classification"


def export(
Expand Down Expand Up @@ -169,15 +143,18 @@ def export(
def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.
For example:
- for transformers model the integration
can be inferred from `config.json`
- for computer vision, the integration
can be inferred from the model architecture (`arch_key`)

:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
raise NotImplementedError
if is_image_classification_model(source_path):
return Integrations.image_classification.value
else:
raise ValueError(
f"Could not infer integration from source_path: {source_path}."
f"Please specify an argument `integration` from one of"
f"the available integrations: {list(Integrations)}."
)


def validate_correctness(deployment_path: Union[Path, str]):
Expand Down
73 changes: 73 additions & 0 deletions src/sparseml/export/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional

from pydantic import BaseModel, Field, validator

from sparseml.pytorch.image_classification.utils.helpers import (
create_model as create_model_ic,
)
from sparsezoo.utils.registry import RegistryMixin


class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"""
Registry that maps integration names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
description="A function that creates a (sparse) "
"PyTorch model from a source path."
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
"with the appropriate structure."
)

# use validator to ensure that "create_model" outputs only the first output
@validator()
def create_model_only_one_output(cls, v):
if v is not None:
v = cls.wrap_to_return_first_output(v)
return v

@staticmethod
def wrap_to_return_first_output(func):
return lambda *args, **kwargs: func(*args, **kwargs)[0]


@IntegrationHelperFunctions.register()
class ImageClassification(IntegrationHelperFunctions):
create_model: Any = Field(default=create_model_ic)
57 changes: 56 additions & 1 deletion src/sparseml/pytorch/image_classification/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import warnings
from contextlib import nullcontext
from enum import Enum, auto, unique
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -344,17 +345,24 @@ def get_dataset_and_dataloader(
# Model creation Helpers
def create_model(
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
checkpoint_path: str,
num_classes: int,
dataset_name: Optional[str] = None,
dataset_path: Optional[str] = None,
num_classes: Optional[int] = None,
recipe_path: Optional[str] = None,
arch_key: Optional[str] = None,
pretrained: Union[bool, str] = False,
pretrained_dataset: Optional[str] = None,
image_size: int = 224,
local_rank: int = -1,
**model_kwargs,
) -> Tuple[Module, str, str]:
"""
:param checkpoint_path: Path to the checkpoint to load. `zoo` for
downloading weights with respect to a SparseZoo recipe
:param dataset_name: The name of the dataset to use for model creation.
Defaults to `None`
:param dataset_path: The path to the dataset to use for model creation.
Defaults to `None`
:param num_classes: Integer representing the number of output classes
:param recipe_path: Path or SparseZoo stub to the recipe for downloading,
respective model. Defaults to `None`
Expand All @@ -364,11 +372,44 @@ def create_model(
False
:param pretrained_dataset: The dataset to used for pretraining. Defaults to
None
:param image_size: The image size to use for inference of num_classes
(in case num_classes is None) . Defaults to 224
:param local_rank: The local rank of the process. Defaults to -1
:param model_kwargs: Additional keyword arguments to pass to the model
:returns: A tuple containing the mode, the model's arch_key, and the
checkpoint path
"""
if num_classes is None:
# infer number of classes from the dataset
if dataset_name is None and dataset_path is None:
raise ValueError(
"To create a model either specify num_classes or "
"dataset_name and dataset_path (so that num_classes can be inferred)"
)
val_dataset, _ = get_dataset_and_dataloader(
dataset_name=dataset_name,
dataset_path=dataset_path,
batch_size=1,
image_size=image_size,
training=False,
loader_num_workers=1,
loader_pin_memory=False,
max_samples=1,
)

num_classes = infer_num_classes(
train_dataset=None,
val_dataset=val_dataset,
dataset=dataset_name,
model_kwargs=model_kwargs,
)
else:
if dataset_name is not None or dataset_path is not None:
warnings.warn(
"Both num_classes and dataset_name/dataset_path were provided. "
"Using num_classes and ignoring dataset_name/dataset_path"
)

with torch_distributed_zero_first(local_rank):
# only download once locally
if checkpoint_path and checkpoint_path.startswith("zoo"):
Expand Down Expand Up @@ -643,3 +684,17 @@ def _download_model_from_zoo_using_recipe(
Model(recipe_stub),
recipe_name=recipe_type,
)


def is_image_classification_model(source_path: Union[Path, str]) -> bool:
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
"""
:param source_path: The path to the model
:return: Whether the model is an image classification model or not
"""
try:
checkpoint = torch.load(os.path.join(source_path, "model.pth"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this expects directory, we should also allow a file for IC

arch_key = checkpoint.get("arch_key")
if arch_key:
return True
except Exception:
return False
13 changes: 13 additions & 0 deletions tests/sparseml/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
22 changes: 22 additions & 0 deletions tests/sparseml/export/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.sparseml.export.registry import IntegrationHelperFunctions


def test_integration_helper_functions():
image_classification = IntegrationHelperFunctions.load_from_registry(
"ImageClassification"
)
assert image_classification.create_model
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from pathlib import Path

import pytest

from sparseml.pytorch.image_classification.utils.helpers import save_zoo_directory
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
save_zoo_directory,
)
from sparsezoo import Model


Expand Down Expand Up @@ -48,3 +52,13 @@ def test_save_zoo_directory(stub, tmp_path_factory):
assert new_zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
shutil.rmtree(path_to_training_outputs)
shutil.rmtree(save_dir)


@pytest.mark.parametrize(
"stub, is_image_classification",
[("zoo:efficientnet_v2-s-imagenet-base_quantized", True)],
)
def test_is_image_classification_model(stub, is_image_classification):
path_to_model = Model(stub).training.path
assert is_image_classification_model(path_to_model)
assert is_image_classification_model(Path(path_to_model))