Skip to content

Commit

Permalink
[Export Refactor][Image Classification] create_dummy_input function (
Browse files Browse the repository at this point in the history
…#1880)

* initial commit

* looking good, time to cleanup

* Delete src/sparseml/export/helpers.py

* Delete tests/sparseml/export/test_helpers.py

* ready for review

* improve design

* tests pass

* reuse _validate_dataset_num_classes

* initial commit

* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py

* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py

* ready for review

* Update src/sparseml/export/export.py

* Update src/sparseml/integration_helper_functions.py
  • Loading branch information
dbogunowicz committed Dec 11, 2023
1 parent f2f54da commit bac0802
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 17 deletions.
7 changes: 5 additions & 2 deletions src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,12 @@ def export(
IntegrationHelperFunctions.load_from_registry(integration)
)

model = helper_functions.create_model(source_path, device)
# for now, this code is not runnable, serves as a blueprint
model, auxiliary_items = helper_functions.create_model(
source_path, **kwargs # noqa: F821
)
sample_data = (
helper_functions.create_dummy_input(model, batch_size)
helper_functions.create_dummy_input(**auxiliary_items)
if sample_data is None
else sample_data
)
Expand Down
29 changes: 21 additions & 8 deletions src/sparseml/integration_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from enum import Enum
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from pydantic import BaseModel, Field

Expand All @@ -39,14 +39,27 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path and additional "
"arguments"
create_model: Optional[
Callable[
[Union[str, Path], Optional[Dict[str, Any]]][
"torch.nn.Module", Optional[Dict[str, Any]] # noqa F821
]
]
] = Field(
description="A function that takes: "
"- a source path to a PyTorch model "
"- (optionally) a dictionary of additional arguments"
"and returns: "
"- a (sparse) PyTorch model "
"- (optionally) a dictionary of additional arguments"
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
create_dummy_input: Optional[
Callable[Any]["torch.Tensor"] # noqa F821
] = Field( # noqa: F82
description="A function that takes: "
"- a dictionary of arguments"
"and returns: "
"- a dummy input for the model (a torch.Tensor) "
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from pathlib import Path
from typing import Any, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from pydantic import Field
Expand All @@ -27,17 +27,46 @@
)


def create_model(source_path: Union[Path, str], **kwargs) -> torch.nn.Module:
def create_model(
source_path: Union[Path, str], **kwargs
) -> Tuple[torch.nn.Module, Dict[str, Any]]:
"""
A contract to create a model from a source path
:param source_path: The path to the model
:return: The torch model
:param kwargs: Additional kwargs to pass to the model creation function
:return: A tuple of the
- torch model
- additional dictionary of useful objects created during model creation
"""
model, *_ = create_image_classification_model(checkpoint_path=source_path, **kwargs)
return model
model, *_, validation_loader = create_image_classification_model(
checkpoint_path=source_path, **kwargs
)
return model, dict(validation_loader=validation_loader)


def create_dummy_input(
validation_loader: Optional[torch.utils.data.DataLoader] = None,
image_size: Optional[int] = 224,
) -> torch.Tensor:
"""
A contract to create a dummy input for a model
:param validation_loader: The validation loader to get a batch from.
If None, a fake batch will be created
:param image_size: The image size to use for the dummy input. Defaults to 224
:return: The dummy input as a torch tensor
"""

if not validation_loader:
# create fake data for export
validation_loader = [[torch.randn(1, 3, image_size, image_size)]]
return next(iter(validation_loader))[0]


@IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
class ImageClassification(IntegrationHelperFunctions):
create_model: Any = Field(default=create_model)
create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
default=create_model
)
create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
2 changes: 1 addition & 1 deletion src/sparseml/pytorch/image_classification/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def create_model(
if one_shot is not None:
ScheduledModifierManager.from_yaml(file_path=one_shot).apply(module=model)

return model, arch_key, checkpoint_path
return model, arch_key, checkpoint_path, val_dataset


def infer_num_classes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ def test_integration_helper_functions():
Integrations.image_classification.value
)
assert image_classification.create_model
assert image_classification.create_dummy_input

0 comments on commit bac0802

Please sign in to comment.