Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Export Refactor][Image Classification] create_dummy_input function #1880

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,12 @@ def export(
IntegrationHelperFunctions.load_from_registry(integration)
)

model = helper_functions.create_model(source_path, device)
# for now, this code is not runnable, serves as a blueprint
model, auxiliary_items = helper_functions.create_model(
source_path, **kwargs # noqa: F821
)
sample_data = (
helper_functions.create_dummy_input(model, batch_size)
helper_functions.create_dummy_input(**auxiliary_items)
if sample_data is None
else sample_data
)
Expand Down
29 changes: 21 additions & 8 deletions src/sparseml/integration_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from enum import Enum
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

from pydantic import BaseModel, Field

Expand All @@ -39,14 +39,27 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
integration.
"""

create_model: Optional[Callable] = Field(
description="A function that creates a (sparse) "
"PyTorch model from a source path and additional "
"arguments"
create_model: Optional[
Callable[
[Union[str, Path], Optional[Dict[str, Any]]][
"torch.nn.Module", Optional[Dict[str, Any]] # noqa F821
]
]
] = Field(
description="A function that takes: "
"- a source path to a PyTorch model "
"- (optionally) a dictionary of additional arguments"
"and returns: "
"- a (sparse) PyTorch model "
"- (optionally) a dictionary of additional arguments"
)
create_dummy_input: Optional[Callable] = Field(
description="A function that creates a dummy input "
"given a (sparse) PyTorch model."
create_dummy_input: Optional[
Callable[Any]["torch.Tensor"] # noqa F821
] = Field( # noqa: F82
description="A function that takes: "
"- a dictionary of arguments"
"and returns: "
"- a dummy input for the model (a torch.Tensor) "
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from pathlib import Path
from typing import Any, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from pydantic import Field
Expand All @@ -27,17 +27,46 @@
)


def create_model(source_path: Union[Path, str], **kwargs) -> torch.nn.Module:
def create_model(
source_path: Union[Path, str], **kwargs
) -> Tuple[torch.nn.Module, Dict[str, Any]]:
"""
A contract to create a model from a source path

:param source_path: The path to the model
:return: The torch model
:param kwargs: Additional kwargs to pass to the model creation function
:return: A tuple of the
- torch model
- additional dictionary of useful objects created during model creation
"""
model, *_ = create_image_classification_model(checkpoint_path=source_path, **kwargs)
return model
model, *_, validation_loader = create_image_classification_model(
checkpoint_path=source_path, **kwargs
)
return model, dict(validation_loader=validation_loader)


def create_dummy_input(
validation_loader: Optional[torch.utils.data.DataLoader] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how do we expect this optional value to flow from the create model function? thinking we need to accept **kwargs in this signature to avoid edge cases for extra kwargs (ie arch key that we'll need)

Copy link
Contributor Author

@dbogunowicz dbogunowicz Dec 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the flow is quite elegant and well-designed:

# def [helper_functions].create_model(...) -> Tuple[torch.nn.Module, Dict[str, Any]]

model: torch.nn.Module, auxiliary_items: Optional[Dict[str, Any]] = helper_functions.create_model(...)

# in case of IC, auxiliary_items is either dict(validation_loader=validation_loader) or dict(validation_loader=None), so this should be picked up by `create_dummy_input`

Maybe i'd call the variable more generically, i.e. data_loader instead of validation_loader

image_size: Optional[int] = 224,
) -> torch.Tensor:
"""
A contract to create a dummy input for a model

:param validation_loader: The validation loader to get a batch from.
If None, a fake batch will be created
:param image_size: The image size to use for the dummy input. Defaults to 224
:return: The dummy input as a torch tensor
"""

if not validation_loader:
# create fake data for export
validation_loader = [[torch.randn(1, 3, image_size, image_size)]]
return next(iter(validation_loader))[0]


@IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
class ImageClassification(IntegrationHelperFunctions):
create_model: Any = Field(default=create_model)
create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
default=create_model
)
create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
2 changes: 1 addition & 1 deletion src/sparseml/pytorch/image_classification/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def create_model(
if one_shot is not None:
ScheduledModifierManager.from_yaml(file_path=one_shot).apply(module=model)

return model, arch_key, checkpoint_path
return model, arch_key, checkpoint_path, val_dataset


def infer_num_classes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ def test_integration_helper_functions():
Integrations.image_classification.value
)
assert image_classification.create_model
assert image_classification.create_dummy_input