[Export Refactor][Image Classification] create_model function (#1878)
* initial commit

* looking good, time to cleanup

* Delete src/sparseml/export/helpers.py

* Delete tests/sparseml/export/test_helpers.py

* ready for review

* improve design

* tests pass

* reuse _validate_dataset_num_classes
dbogunowicz committed Dec 11, 2023
1 parent 497c008 commit f2f54da
Showing 9 changed files with 291 additions and 56 deletions.
2 changes: 1 addition & 1 deletion src/sparseml/core/session.py
@@ -100,7 +100,7 @@ def pre_initialize_structure(
This will run the pre-initialize structure method for each modifier in the
session's lifecycle. This will also set the session's state to the
pre-initialized state. Takes care of cases when the model(s) structure
has been previosuly modified by a modifier.
has been previously modified by a modifier.
:param model: the model to pre-initialize the structure for
:param recipe: the recipe to use for the sparsification, can be a path to a
13 changes: 13 additions & 0 deletions src/sparseml/export/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
58 changes: 5 additions & 53 deletions src/sparseml/export.py → src/sparseml/export/export.py
@@ -14,53 +14,19 @@

import logging
from pathlib import Path
from typing import Any, Callable, List, Optional, Union

from pydantic import BaseModel, Field
from typing import Any, List, Optional, Union

from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
infer_integration,
)
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparsezoo.utils.registry import RegistryMixin


_LOGGER = logging.getLogger(__name__)
AVAILABLE_DEPLOYMENT_TARGETS = ["deepsparse", "onnxruntime"]


class IntegrationHelperFunctions(BaseModel, RegistryMixin):
"""
Registry that maps integration names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path."
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
"with the appropriate structure."
)


def export(
source_path: Union[Path, str],
target_path: Union[Path, str],
@@ -166,20 +132,6 @@ def export(
)


def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.
For example:
- for transformers model the integration
can be inferred from `config.json`
- for computer vision, the integration
can be inferred from the model architecture (`arch_key`)
:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
raise NotImplementedError


def validate_correctness(deployment_path: Union[Path, str]):
"""
Validate the correctness of the exported model.
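A sketch of how the relocated entrypoint might be invoked after this refactor; only source_path, target_path, and an optional `integration` argument are evidenced by this diff, and the paths below are placeholders.

# Hypothetical invocation -- parameters beyond those shown here are collapsed in this diff view.
from sparseml.export.export import export

export(
    source_path="./image_classification_checkpoint",  # assumed directory containing model.pth
    target_path="./deployment",
    # optional; when omitted, infer_integration() resolves it from source_path
    integration="image-classification",
)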
92 changes: 92 additions & 0 deletions src/sparseml/integration_helper_functions.py
@@ -0,0 +1,92 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from pathlib import Path
from typing import Callable, Optional, Union

from pydantic import BaseModel, Field

from sparsezoo.utils.registry import RegistryMixin


__all__ = ["IntegrationHelperFunctions", "infer_integration"]


class Integrations(Enum):
"""
Holds the names of the available integrations.
"""

image_classification = "image-classification"


class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"""
Registry that maps names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path and additional "
"arguments"
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
"with the appropriate structure."
)


def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.
:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)

if is_image_classification_model(source_path):
# import to register the image_classification integration helper functions
import sparseml.pytorch.image_classification.integration_helper_functions # noqa F401

return Integrations.image_classification.value
else:
raise ValueError(
f"Could not infer integration from source_path: {source_path}."
f"Please specify an argument `integration` from one of"
f"the available integrations: {list(Integrations)}."
)
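A minimal usage sketch of the new module, using only names introduced in this diff; the checkpoint path and num_classes value are assumptions, and in the export flow the helper class would be resolved from the registry by name rather than imported directly.

from sparseml.integration_helper_functions import infer_integration
from sparseml.pytorch.image_classification.integration_helper_functions import (
    ImageClassification,
)

source_path = "./checkpoints/resnet50_pruned.pth"  # assumed checkpoint storing `arch_key`
assert infer_integration(source_path) == "image-classification"

# Registered under "image-classification" via IntegrationHelperFunctions.register;
# the export flow is expected to look this class up in the registry by that name.
helpers = ImageClassification()
model = helpers.create_model(source_path, num_classes=1000)  # kwargs are integration-specific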
43 changes: 43 additions & 0 deletions src/sparseml/pytorch/image_classification/integration_helper_functions.py
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any, Union

import torch
from pydantic import Field

from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
)
from sparseml.pytorch.image_classification.utils.helpers import (
create_model as create_image_classification_model,
)


def create_model(source_path: Union[Path, str], **kwargs) -> torch.nn.Module:
"""
A contract to create a model from a source path.
:param source_path: The path to the model
:param kwargs: Additional arguments forwarded to the
image-classification create_model helper
:return: The torch model
"""
model, *_ = create_image_classification_model(checkpoint_path=source_path, **kwargs)
return model


@IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
class ImageClassification(IntegrationHelperFunctions):
create_model: Any = Field(default=create_model)
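A hypothetical direct call into the wrapper registered above; the checkpoint location and recipe path are assumed, and num_classes is passed explicitly so no dataset arguments are required (see the validation helper added to utils/helpers.py below).

from sparseml.pytorch.image_classification.integration_helper_functions import (
    create_model,
)

# kwargs are forwarded verbatim to the image-classification create_model helper
model = create_model(
    "./checkpoints/resnet50_pruned.pth",  # assumed checkpoint file
    num_classes=1000,
    one_shot="./recipes/quantization_one_shot.yaml",  # assumed recipe applied one-shot
)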
83 changes: 82 additions & 1 deletion src/sparseml/pytorch/image_classification/utils/helpers.py
@@ -21,6 +21,7 @@
import warnings
from contextlib import nullcontext
from enum import Enum, auto, unique
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
@@ -47,6 +48,7 @@
default_device,
download_framework_model_by_recipe_type,
early_stop_data_loader,
load_model,
model_to_device,
set_deterministic_seeds,
torch_distributed_zero_first,
@@ -344,17 +346,25 @@ def get_dataset_and_dataloader(
# Model creation Helpers
def create_model(
checkpoint_path: str,
num_classes: int,
dataset_name: Optional[str] = None,
dataset_path: Optional[str] = None,
num_classes: Optional[int] = None,
recipe_path: Optional[str] = None,
arch_key: Optional[str] = None,
pretrained: Union[bool, str] = False,
pretrained_dataset: Optional[str] = None,
one_shot: Optional[str] = None,
image_size: int = 224,
local_rank: int = -1,
**model_kwargs,
) -> Tuple[Module, str, str]:
"""
:param checkpoint_path: Path to the checkpoint to load. `zoo` for
downloading weights with respect to a SparseZoo recipe
:param dataset_name: The name of the dataset to use for model creation.
Defaults to `None`
:param dataset_path: The path to the dataset to use for model creation.
Defaults to `None`
:param num_classes: Integer representing the number of output classes
:param recipe_path: Path or SparseZoo stub to the recipe used to download the
respective model. Defaults to `None`
@@ -364,11 +374,38 @@ def create_model(
False
:param pretrained_dataset: The dataset used for pretraining. Defaults to
None
:param one_shot: The recipe to be applied in a one-shot manner
before exporting. Defaults to None
:param image_size: The image size to use for inferring num_classes
(in case num_classes is None). Defaults to 224
:param local_rank: The local rank of the process. Defaults to -1
:param model_kwargs: Additional keyword arguments to pass to the model
:returns: A tuple containing the model, the model's arch_key, and the
checkpoint path
"""
_validate_dataset_num_classes(
dataset_path=dataset_path, dataset=dataset_name, num_classes=num_classes
)

if num_classes is None:
val_dataset, _ = get_dataset_and_dataloader(
dataset_name=dataset_name,
dataset_path=dataset_path,
batch_size=1,
image_size=image_size,
training=False,
loader_num_workers=1,
loader_pin_memory=False,
max_samples=1,
)

num_classes = infer_num_classes(
train_dataset=None,
val_dataset=val_dataset,
dataset=dataset_name,
model_kwargs=model_kwargs,
)

with torch_distributed_zero_first(local_rank):
# only download once locally
if checkpoint_path and checkpoint_path.startswith("zoo"):
@@ -399,6 +436,16 @@
else:
model, arch_key = result

# TODO: discuss how this is related to the above application of recipes
if recipe_path is not None:
# TODO: replace this with a new manager introduced by @satrat
ScheduledModifierManager.from_yaml(recipe_path).apply_structure(model)
if checkpoint_path:
load_model(checkpoint_path, model, strict=True)

if one_shot is not None:
ScheduledModifierManager.from_yaml(file_path=one_shot).apply(module=model)

return model, arch_key, checkpoint_path


@@ -643,3 +690,37 @@ def _download_model_from_zoo_using_recipe(
Model(recipe_stub),
recipe_name=recipe_type,
)


def is_image_classification_model(source_path: Union[Path, str]) -> bool:
"""
:param source_path: The path to the model
:return: Whether the model is an image classification model or not
"""
if not os.path.isfile(source_path):
checkpoint_path = os.path.join(source_path, "model.pth")
else:
checkpoint_path = source_path
try:
checkpoint = torch.load(checkpoint_path)
arch_key = checkpoint.get("arch_key")
# image-classification checkpoints store the model architecture under `arch_key`
return bool(arch_key)
except Exception:
return False


def _validate_dataset_num_classes(
dataset: str,
dataset_path: str,
num_classes: int,
):
if dataset and not dataset_path:
raise ValueError(f"found dataset {dataset} but dataset_path not specified")
if dataset_path and not dataset:
raise ValueError(f"found dataset_path {dataset_path} but dataset not specified")
if num_classes is None and (not dataset or not dataset_path):
raise ValueError(
"If num_classes is not provided, both dataset and dataset_path must be "
"set to infer num_classes"
)
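Illustrative calls showing which argument combinations the new validation accepts; the checkpoint, dataset name, and paths below are placeholders.

from sparseml.pytorch.image_classification.utils.helpers import create_model

# 1) num_classes given explicitly -- no dataset arguments required
model, arch_key, checkpoint = create_model(
    checkpoint_path="./checkpoints/resnet50_pruned.pth",
    num_classes=1000,
)

# 2) num_classes omitted -- both dataset_name and dataset_path are required so a
#    single validation sample can be loaded and infer_num_classes can run
model, arch_key, checkpoint = create_model(
    checkpoint_path="./checkpoints/resnet50_pruned.pth",
    dataset_name="imagenette",  # assumed registered dataset name
    dataset_path="./data/imagenette",
)

# 3) num_classes omitted and no dataset pair -> _validate_dataset_num_classes raises ValueError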
13 changes: 13 additions & 0 deletions tests/sparseml/export/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
(Diffs for the remaining changed files are not shown.)
