[Export Refactor][Image Classification] create_dummy_input function #1880

Merged
Changes from 13 commits
2 changes: 1 addition & 1 deletion src/sparseml/core/session.py
@@ -100,7 +100,7 @@ def pre_initialize_structure(
This will run the pre-initialize structure method for each modifier in the
session's lifecycle. This will also set the session's state to the
pre-initialized state. Takes care of cases when the model(s) structure
has been previosuly modified by a modifier.
has been previously modified by a modifier.

:param model: the model to pre-initialize the structure for
:param recipe: the recipe to use for the sparsification, can be a path to a
13 changes: 13 additions & 0 deletions src/sparseml/export/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
65 changes: 10 additions & 55 deletions src/sparseml/export.py → src/sparseml/export/export.py
@@ -14,53 +14,19 @@

import logging
from pathlib import Path
from typing import Any, Callable, List, Optional, Union

from pydantic import BaseModel, Field
from typing import Any, List, Optional, Union

from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
infer_integration,
)
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparsezoo.utils.registry import RegistryMixin


_LOGGER = logging.getLogger(__name__)
AVAILABLE_DEPLOYMENT_TARGETS = ["deepsparse", "onnxruntime"]


class IntegrationHelperFunctions(BaseModel, RegistryMixin):
"""
Registry that maps integration names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path."
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
"with the appropriate structure."
)


def export(
source_path: Union[Path, str],
target_path: Union[Path, str],
@@ -132,9 +98,12 @@ def export(
IntegrationHelperFunctions.load_from_registry(integration)
)

model = helper_functions.create_model(source_path, device)
# for now, this code is not runnable, serves as a blueprint
model, auxiliary_items = helper_functions.create_model(
source_path, **kwargs # noqa: F821
) # type: ignore
sample_data = (
helper_functions.create_dummy_input(model, batch_size)
helper_functions.create_dummy_input(**auxiliary_items)
if sample_data is None
else sample_data
)
@@ -166,20 +135,6 @@ def export(
)


def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.
For example:
- for transformers model the integration
can be inferred from `config.json`
- for computer vision, the integration
can be inferred from the model architecture (`arch_key`)
:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
raise NotImplementedError


def validate_correctness(deployment_path: Union[Path, str]):
"""
Validate the correctness of the exported model.
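For reference while reviewing: a minimal, self-contained sketch of the flow that the blueprint section above expresses (create_model returning a model plus auxiliary items, and create_dummy_input only being called when no sample data is supplied). The stub helpers and the checkpoint path below are placeholders, not part of this diff.

from typing import Any, Dict, Optional, Tuple


def create_model_stub(source_path: str, **kwargs) -> Tuple[Any, Dict[str, Any]]:
    # stand-in for helper_functions.create_model: returns a model plus
    # auxiliary items (e.g. a validation loader) produced during creation
    return "model", {"validation_loader": None}


def create_dummy_input_stub(**auxiliary_items: Any) -> Any:
    # stand-in for helper_functions.create_dummy_input
    return "dummy_input"


def export_blueprint(source_path: str, sample_data: Optional[Any] = None) -> None:
    model, auxiliary_items = create_model_stub(source_path)
    # only synthesize a dummy input when the caller did not supply sample data
    sample_data = (
        create_dummy_input_stub(**auxiliary_items)
        if sample_data is None
        else sample_data
    )
    print(model, sample_data)


export_blueprint("/path/to/checkpoint")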
105 changes: 105 additions & 0 deletions src/sparseml/integration_helper_functions.py
@@ -0,0 +1,105 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union

from pydantic import BaseModel, Field

from sparsezoo.utils.registry import RegistryMixin


__all__ = ["IntegrationHelperFunctions", "infer_integration"]


class Integrations(Enum):
"""
Holds the names of the available integrations.
"""

image_classification = "image-classification"


class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"""
Registry that maps names to helper functions
for creation/export/manipulation of models for a specific
integration.
"""

create_model: Optional[
    Callable[
        [Union[str, Path], Optional[Dict[str, Any]]],
        Tuple["torch.nn.Module", Optional[Dict[str, Any]]],  # noqa: F821
    ]
] = Field(
    description="A function that takes: "
    "- a source path to a PyTorch model "
    "- (optionally) a dictionary of additional arguments "
    "and returns: "
    "- a (sparse) PyTorch model "
    "- (optionally) a dictionary of additional arguments"
)
create_dummy_input: Optional[
    Callable[..., "torch.Tensor"]  # noqa: F821
] = Field(
    description="A function that takes: "
    "- a dictionary of arguments "
    "and returns: "
    "- a dummy input for the model (a torch.Tensor) "
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."
)
create_deployment_folder: Optional[Callable] = Field(
    description="A function that creates a "
    "deployment folder for the exported ONNX model "
    "with the appropriate structure."
)


def infer_integration(source_path: Union[Path, str]) -> str:
"""
Infer the integration to use for exporting the model from the source_path.

:param source_path: The path to the PyTorch model to export.
:return: The name of the integration to use for exporting the model.
"""
from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)

if is_image_classification_model(source_path):
# import to register the image_classification integration helper functions
import sparseml.pytorch.image_classification.integration_helper_functions # noqa F401

return Integrations.image_classification.value
else:
raise ValueError(
f"Could not infer integration from source_path: {source_path}."
f"Please specify an argument `integration` from one of"
f"the available integrations: {list(Integrations)}."
)
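For orientation, a short usage sketch of how infer_integration and the registry are intended to compose, mirroring the blueprint in export.py above. The checkpoint path is a placeholder, and (as the diff notes) the end-to-end flow is not expected to be runnable yet.

from pathlib import Path

from sparseml.integration_helper_functions import (
    IntegrationHelperFunctions,
    infer_integration,
)

# placeholder checkpoint path, for illustration only
source_path = Path("/path/to/image_classification/checkpoint.pth")

# resolve the integration name from the checkpoint; the image-classification
# helper functions register themselves when infer_integration imports their module
integration = infer_integration(source_path)  # -> "image-classification"

# look up the registered helper functions for that integration
helper_functions = IntegrationHelperFunctions.load_from_registry(integration)

# downstream, export() uses the helpers roughly as:
#   model, auxiliary_items = helper_functions.create_model(source_path, **kwargs)
#   dummy_input = helper_functions.create_dummy_input(**auxiliary_items)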
72 changes: 72 additions & 0 deletions src/sparseml/pytorch/image_classification/integration_helper_functions.py
@@ -0,0 +1,72 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from pydantic import Field

from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
)
from src.sparseml.pytorch.image_classification.utils.helpers import (
create_model as create_image_classification_model,
)


def create_model(
source_path: Union[Path, str], **kwargs
) -> Tuple[torch.nn.Module, Dict[str, Any]]:
"""
A contract to create a model from a source path

:param source_path: The path to the model
:param kwargs: Additional kwargs to pass to the model creation function
:return: A tuple of the
- torch model
- additional dictionary of useful objects created during model creation
"""
model, *_, validation_loader = create_image_classification_model(
checkpoint_path=source_path, **kwargs
)
return model, dict(validation_loader=validation_loader)


def create_dummy_input(
validation_loader: Optional[torch.utils.data.DataLoader] = None,
Member

How do we expect this optional value to flow from the create_model function? Thinking we need to accept **kwargs in this signature to avoid edge cases for extra kwargs (i.e. the arch_key that we'll need).

Contributor Author @dbogunowicz Dec 11, 2023

I think that the flow is quite elegant and well-designed:

# def [helper_functions].create_model(...) -> Tuple[torch.nn.Module, Dict[str, Any]]

model: torch.nn.Module, auxiliary_items: Optional[Dict[str, Any]] = helper_functions.create_model(...)

# in case of IC, auxiliary_items is either dict(validation_loader=validation_loader) or dict(validation_loader=None), so this should be picked up by `create_dummy_input`

Maybe I'd call the variable more generically, i.e. data_loader instead of validation_loader.

image_size: Optional[int] = 224,
) -> torch.Tensor:
"""
A contract to create a dummy input for a model

:param validation_loader: The validation loader to get a batch from.
If None, a fake batch will be created
:param image_size: The image size to use for the dummy input. Defaults to 224
:return: The dummy input as a torch tensor
"""

if not validation_loader:
# create fake data for export
validation_loader = [[torch.randn(1, 3, image_size, image_size)]]
return next(iter(validation_loader))[0]


@IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
class ImageClassification(IntegrationHelperFunctions):
create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
default=create_model
)
create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
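To make the contract discussed in the review thread concrete: a standalone sketch (assuming only torch is installed) of how the auxiliary validation_loader, or its absence, drives create_dummy_input. The fake loader is a plain list standing in for a torch DataLoader.

import torch


def create_dummy_input(validation_loader=None, image_size: int = 224) -> torch.Tensor:
    # mirrors the function added above: fall back to a random
    # (1, 3, image_size, image_size) batch when no loader flows in
    if not validation_loader:
        validation_loader = [[torch.randn(1, 3, image_size, image_size)]]
    return next(iter(validation_loader))[0]


# no loader available -> fabricated batch
print(create_dummy_input().shape)  # torch.Size([1, 3, 224, 224])

# loader available -> the first batch's images are used
fake_loader = [[torch.randn(8, 3, 224, 224), torch.arange(8)]]
print(create_dummy_input(validation_loader=fake_loader).shape)  # torch.Size([8, 3, 224, 224])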