[Export Refactor][Image Classification] create_model function #1878

Merged
2 changes: 1 addition & 1 deletion src/sparseml/core/session.py
@@ -100,7 +100,7 @@ def pre_initialize_structure(
This will run the pre-initialize structure method for each modifier in the
session's lifecycle. This will also set the session's state to the
pre-initialized state. Takes care of cases when the model(s) structure
has been previosuly modified by a modifier.
has been previously modified by a modifier.

:param model: the model to pre-initialize the structure for
:param recipe: the recipe to use for the sparsification, can be a path to a
13 changes: 13 additions & 0 deletions src/sparseml/export/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
58 changes: 5 additions & 53 deletions src/sparseml/export.py → src/sparseml/export/export.py
@@ -14,53 +14,19 @@

import logging
from pathlib import Path
from typing import Any, Callable, List, Optional, Union

from pydantic import BaseModel, Field
from typing import Any, List, Optional, Union

from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
infer_integration,
)
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparsezoo.utils.registry import RegistryMixin


_LOGGER = logging.getLogger(__name__)
AVAILABLE_DEPLOYMENT_TARGETS = ["deepsparse", "onnxruntime"]


class IntegrationHelperFunctions(BaseModel, RegistryMixin):
"""
Registry that maps integration names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path."
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
"with the appropriate structure."
)


def export(
source_path: Union[Path, str],
target_path: Union[Path, str],
@@ -166,20 +132,6 @@ def export(
)


def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.
For example:
- for transformers model the integration
can be inferred from `config.json`
- for computer vision, the integration
can be inferred from the model architecture (`arch_key`)
:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
raise NotImplementedError


def validate_correctness(deployment_path: Union[Path, str]):
"""
Validate the correctness of the exported model.
92 changes: 92 additions & 0 deletions src/sparseml/integration_helper_functions.py
@@ -0,0 +1,92 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from pathlib import Path
from typing import Callable, Optional, Union

from pydantic import BaseModel, Field

from sparsezoo.utils.registry import RegistryMixin


__all__ = ["IntegrationHelperFunctions", "infer_integration"]


class Integrations(Enum):
"""
Holds the names of the available integrations.
"""

image_classification = "image-classification"


class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"""
Registry that maps names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path and additional "
"arguments"
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exported ONNX model "
"with the appropriate structure."
)


def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.

:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)

if is_image_classification_model(source_path):
# import to register the image_classification integration helper functions
import sparseml.pytorch.image_classification.integration_helper_functions # noqa F401

return Integrations.image_classification.value
else:
raise ValueError(
f"Could not infer integration from source_path: {source_path}. "
"Please specify an argument `integration` from one of "
f"the available integrations: {list(Integrations)}."
)
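
For context, a minimal sketch of how this registry is meant to be consumed downstream (not part of this diff). It assumes `RegistryMixin` exposes `load_from_registry` as the lookup counterpart to the `register` decorator used by integrations, and the checkpoint path is a placeholder:

from sparseml.integration_helper_functions import (
    IntegrationHelperFunctions,
    infer_integration,
)

# Placeholder path to an image-classification checkpoint (any checkpoint whose
# loaded dict carries an "arch_key" entry resolves to this integration).
source_path = "path/to/image_classification_checkpoint"

# Infer the integration name, then fetch the registered helper-functions class
# and instantiate it to get the per-integration callables.
integration = infer_integration(source_path)  # "image-classification"
helper_functions = IntegrationHelperFunctions.load_from_registry(integration)()

model = helper_functions.create_model(source_path)
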
43 changes: 43 additions & 0 deletions src/sparseml/pytorch/image_classification/integration_helper_functions.py
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any, Union

import torch
from pydantic import Field

from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
)
from sparseml.pytorch.image_classification.utils.helpers import (
create_model as create_image_classification_model,
)


def create_model(source_path: Union[Path, str], **kwargs) -> torch.nn.Module:
"""
A contract to create a model from a source path

:param source_path: The path to the model
:return: The torch model
"""
model, *_ = create_image_classification_model(checkpoint_path=source_path, **kwargs)
return model


@IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
class ImageClassification(IntegrationHelperFunctions):
create_model: Any = Field(default=create_model)
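
A hedged usage sketch for the contract above; the checkpoint path, `arch_key`, and `num_classes` are placeholder values, and extra keyword arguments are forwarded to the underlying image-classification helper:

import torch

from sparseml.pytorch.image_classification.integration_helper_functions import (
    ImageClassification,
    create_model,
)

# Placeholder checkpoint; num_classes is passed explicitly so no dataset
# arguments are needed by the underlying helper.
model = create_model("path/to/model.pth", arch_key="resnet50", num_classes=1000)
assert isinstance(model, torch.nn.Module)

# The same callable is exposed through the registered helper-functions class.
helpers = ImageClassification()
model = helpers.create_model(
    "path/to/model.pth", arch_key="resnet50", num_classes=1000
)
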
83 changes: 82 additions & 1 deletion src/sparseml/pytorch/image_classification/utils/helpers.py
@@ -21,6 +21,7 @@
import warnings
from contextlib import nullcontext
from enum import Enum, auto, unique
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
@@ -47,6 +48,7 @@
default_device,
download_framework_model_by_recipe_type,
early_stop_data_loader,
load_model,
model_to_device,
set_deterministic_seeds,
torch_distributed_zero_first,
@@ -344,17 +346,25 @@ def get_dataset_and_dataloader(
# Model creation Helpers
def create_model(
checkpoint_path: str,
num_classes: int,
dataset_name: Optional[str] = None,
dataset_path: Optional[str] = None,
num_classes: Optional[int] = None,
recipe_path: Optional[str] = None,
arch_key: Optional[str] = None,
pretrained: Union[bool, str] = False,
pretrained_dataset: Optional[str] = None,
one_shot: Optional[str] = None,
image_size: int = 224,
local_rank: int = -1,
**model_kwargs,
) -> Tuple[Module, str, str]:
"""
:param checkpoint_path: Path to the checkpoint to load. `zoo` for
downloading weights with respect to a SparseZoo recipe
:param dataset_name: The name of the dataset to use for model creation.
Defaults to `None`
:param dataset_path: The path to the dataset to use for model creation.
Defaults to `None`
:param num_classes: Integer representing the number of output classes
:param recipe_path: Path or SparseZoo stub to the recipe for downloading,
respective model. Defaults to `None`
@@ -364,11 +374,38 @@ def create_model(
False
:param pretrained_dataset: The dataset to use for pretraining. Defaults to
None
:param one_shot: The recipe to be applied in one-shot manner,
before exporting. Defaults to None
:param image_size: The image size to use for inference of num_classes
(in case num_classes is None). Defaults to 224
:param local_rank: The local rank of the process. Defaults to -1
:param model_kwargs: Additional keyword arguments to pass to the model
:returns: A tuple containing the model, the model's arch_key, and the
checkpoint path
"""
_validate_dataset_num_classes(
dataset_path=dataset_path, dataset=dataset_name, num_classes=num_classes
)

if num_classes is None:
val_dataset, _ = get_dataset_and_dataloader(
dataset_name=dataset_name,
dataset_path=dataset_path,
batch_size=1,
image_size=image_size,
training=False,
loader_num_workers=1,
loader_pin_memory=False,
max_samples=1,
)

num_classes = infer_num_classes(
train_dataset=None,
val_dataset=val_dataset,
dataset=dataset_name,
model_kwargs=model_kwargs,
)

with torch_distributed_zero_first(local_rank):
# only download once locally
if checkpoint_path and checkpoint_path.startswith("zoo"):
@@ -399,6 +436,16 @@
else:
model, arch_key = result

# TODO: discuss how this is related to the above application of recipes
if recipe_path is not None:
# TODO: replace this with a new manager introduced by @satrat
ScheduledModifierManager.from_yaml(recipe_path).apply_structure(model)
if checkpoint_path:
load_model(checkpoint_path, model, strict=True)

if one_shot is not None:
ScheduledModifierManager.from_yaml(file_path=one_shot).apply(module=model)

return model, arch_key, checkpoint_path
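
To illustrate the extended signature, a hedged example call; paths, `arch_key`, and the dataset values are placeholders. When `num_classes` is given the dataset arguments can stay unset; when it is omitted, both `dataset_name` and `dataset_path` must be supplied so the value can be inferred from a single validation sample:

from sparseml.pytorch.image_classification.utils.helpers import create_model

# Explicit num_classes: no dataset required.
model, arch_key, checkpoint_path = create_model(
    checkpoint_path="path/to/model.pth",
    num_classes=1000,
    arch_key="resnet50",
)

# num_classes omitted: it is inferred from one validation sample, so both
# dataset_name and dataset_path are required (values are illustrative).
model, arch_key, checkpoint_path = create_model(
    checkpoint_path="path/to/model.pth",
    dataset_name="imagenette",
    dataset_path="path/to/imagenette",
)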


@@ -643,3 +690,37 @@ def _download_model_from_zoo_using_recipe(
Model(recipe_stub),
recipe_name=recipe_type,
)


def is_image_classification_model(source_path: Union[Path, str]) -> bool:
"""
:param source_path: The path to the model
:return: Whether the model is an image classification model or not
"""
if not os.path.isfile(source_path):
checkpoint_path = os.path.join(source_path, "model.pth")
else:
checkpoint_path = source_path
try:
checkpoint = torch.load(checkpoint_path)
arch_key = checkpoint.get("arch_key")
if arch_key:
return True
except Exception:
return False


def _validate_dataset_num_classes(
dataset: str,
dataset_path: str,
num_classes: int,
):
if dataset and not dataset_path:
raise ValueError(f"found dataset {dataset} but dataset_path not specified")
if dataset_path and not dataset:
raise ValueError(f"found dataset_path {dataset_path} but dataset not specified")
if num_classes is None and (not dataset or not dataset_path):
raise ValueError(
"If num_classes is not provided, both dataset and dataset_path must be "
"set to infer num_classes"
)
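
A short, hedged sketch of the detection heuristic above; the saved checkpoint is synthetic and only mirrors what the helper looks for (an `arch_key` entry in the loaded checkpoint dict):

import torch

from sparseml.pytorch.image_classification.utils.helpers import (
    is_image_classification_model,
)

# Write a synthetic checkpoint that mimics an image-classification export.
torch.save({"arch_key": "resnet50", "state_dict": {}}, "model.pth")

assert is_image_classification_model("model.pth")
# A directory works as well, provided it contains a "model.pth" file.
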
13 changes: 13 additions & 0 deletions tests/sparseml/export/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.