[Export Refactor][Image Classification] export_sample_inputs_outputs function #1888

Merged
Changes from all commits
25 commits
6d8991c
initial commit
dbogunowicz Dec 5, 2023
059ceab
looking good, time to cleanup
dbogunowicz Dec 6, 2023
b390c2e
Delete src/sparseml/export/helpers.py
dbogunowicz Dec 6, 2023
c2c8444
Delete tests/sparseml/export/test_helpers.py
dbogunowicz Dec 6, 2023
6ce6ba5
ready for review
dbogunowicz Dec 6, 2023
6f3e5e7
Merge branch 'feature/damian/create_model_ic' of github.com:neuralmag…
dbogunowicz Dec 6, 2023
5dfbdcd
improve design
dbogunowicz Dec 6, 2023
042c193
tests pass
dbogunowicz Dec 6, 2023
29cfa1d
reuse _validate_dataset_num_classes
dbogunowicz Dec 6, 2023
ab73aec
initial commit
dbogunowicz Dec 6, 2023
f628532
Update src/sparseml/pytorch/image_classification/integration_helper_f…
dbogunowicz Dec 6, 2023
b93b634
Update src/sparseml/pytorch/image_classification/integration_helper_f…
dbogunowicz Dec 6, 2023
e7606cd
ready for review
dbogunowicz Dec 7, 2023
ea9cb61
Update src/sparseml/export/export.py
dbogunowicz Dec 7, 2023
9572e0b
Update src/sparseml/integration_helper_functions.py
dbogunowicz Dec 7, 2023
0229deb
initial commit
dbogunowicz Dec 7, 2023
a8c1b68
fixes
dbogunowicz Dec 7, 2023
2379354
ready for review
dbogunowicz Dec 7, 2023
741fb12
nit
dbogunowicz Dec 7, 2023
ebdeb9f
add return
dbogunowicz Dec 7, 2023
8b2fca0
initial commit
dbogunowicz Dec 7, 2023
ff52598
initial commit
dbogunowicz Dec 8, 2023
50ca948
PR comments
dbogunowicz Dec 11, 2023
2f71f7b
beautification
dbogunowicz Dec 11, 2023
cabc17e
Merge remote-tracking branch 'origin/feature/damian/feature_branch_ex…
dbogunowicz Dec 11, 2023
35 changes: 28 additions & 7 deletions src/sparseml/export/export.py
@@ -17,6 +17,7 @@
from typing import Any, List, Optional, Union

from sparseml.export.helpers import apply_optimizations
from sparseml.export.export_data import export_data_samples
from sparseml.exporters import ExportTargets
from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
@@ -42,7 +43,7 @@ def export(
single_graph_file: bool = True,
graph_optimizations: Union[str, List[str], None] = "all",
validate_correctness: bool = False,
export_sample_inputs_outputs: bool = False,
num_export_samples: int = 0,
deployment_directory_name: str = "deployment",
device: str = "auto",
):
@@ -81,8 +82,8 @@
to the exported model. Defaults to 'all'.
:param validate_correctness: Whether to validate the correctness
of the exported model. Defaults to False.
:param export_sample_inputs_outputs: Whether to export sample
inputs and outputs for the exported model.Defaults to False.
:param num_export_samples: The number of samples to export for
the exported model. Defaults to 0.
:param deployment_directory_name: The name of the deployment
directory to create for the exported model. Thus, the exported
model will be saved to `target_path/deployment_directory_name`.
@@ -123,18 +124,38 @@ def export(
single_graph_file=single_graph_file,
)

if export_sample_inputs_outputs:
helper_functions.export_sample_inputs_outputs(model, target_path)
if num_export_samples:
data_loader = auxiliary_items.get("validation_loader")
if data_loader is None:
raise ValueError(
"To export sample inputs/outputs a data loader is needed."
"To enable the export, provide a `validatation_loader` "
"as a part of `auxiliary_items` output of the `create_model` function."
)
(
input_samples,
output_samples,
label_samples,
) = helper_functions.create_data_samples(
num_samples=num_export_samples, data_loader=data_loader, model=model
)
export_data_samples(
input_samples=input_samples,
output_samples=output_samples,
label_samples=label_samples,
target_path=target_path,
as_tar=True,
)

deployment_path = helper_functions.create_deployment_folder(
source_path, target_path, deployment_directory_name
)

if validate_correctness:
if not export_sample_inputs_outputs:
if not num_export_samples:
raise ValueError(
"To validate correctness sample inputs/outputs are needed."
"To enable the validation, set `export_sample_inputs_outputs`"
"To enable the validation, set `num_export_samples`"
"to True"
)
validate_correctness(deployment_path)
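For context, a minimal sketch of calling the refactored entrypoint with the new argument. The `source_path`/`target_path` arguments and the import path `sparseml.export.export` are assumptions inferred from this diff, not verified against the full signature:

```python
# Hypothetical usage of the refactored export() entrypoint.
# Assumptions (not fully shown in this diff): the function also takes
# `source_path` and `target_path`, and lives at sparseml.export.export.
from sparseml.export.export import export

export(
    source_path="./training_run",    # hypothetical checkpoint directory
    target_path="./exported_model",  # hypothetical output directory
    num_export_samples=20,           # replaces the old export_sample_inputs_outputs bool
    validate_correctness=True,       # requires num_export_samples > 0
)
```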
138 changes: 138 additions & 0 deletions src/sparseml/export/export_data.py
@@ -0,0 +1,138 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
import tarfile
from enum import Enum
from pathlib import Path
from typing import List, Optional, Tuple, Union

from tqdm import tqdm


__all__ = ["create_data_samples"]

_LOGGER = logging.getLogger(__name__)


class LabelNames(Enum):
basename = "sample-labels"
filename = "lab"


class OutputsNames(Enum):
basename = "sample-outputs"
filename = "out"


class InputsNames(Enum):
basename = "sample-inputs"
filename = "inp"


def create_data_samples(
data_loader: "torch.utils.data.DataLoader", # noqa F821
model: Optional["torch.nn.Module"] = None, # noqa F821
num_samples: int = 1,
) -> Tuple[
List["torch.Tensor"], # noqa F821
Optional[List["torch.Tensor"]], # noqa F821
List["torch.Tensor"], # noqa F821
]:
"""
Fetch a batch of samples from the data loader and return the inputs, outputs and labels
:param data_loader: The data loader to get batches of inputs/labels from.
:param model: An optional model; if provided, it is run on the inputs to generate outputs.
:param num_samples: The number of samples to generate. Defaults to 1
:return: The inputs, outputs and labels as lists of torch tensors
"""
inputs, outputs, labels = [], [], []
for batch_num, (inputs_, labels_) in tqdm(enumerate(data_loader)):
if batch_num == num_samples:
break
if model:
outputs_ = model(inputs_)
outputs.append(outputs_)
inputs.append(inputs_)
labels.append(labels_)

return inputs, outputs, labels


def export_data_samples(
target_path: Union[Path, str],
input_samples: Optional[List["torch.Tensor"]] = None, # noqa F821
output_samples: Optional[List["torch.Tensor"]] = None, # noqa F821
label_samples: Optional[List["torch.Tensor"]] = None, # noqa F821
as_tar: bool = False,
):
"""
Save the input, labels and output samples to the target path.
All the input files are optional. If a sample is None,
it will not be saved.
Input samples will be saved to:
.../sample-inputs/inp_0001.npz
.../sample-inputs/inp_0002.npz
...
Output samples will be saved to:
.../sample-outputs/out_0001.npz
.../sample-outputs/out_0002.npz
...
Label samples will be saved to:
.../sample-labels/lab_0001.npz
.../sample-labels/lab_0002.npz
...
If as_tar is True, the samples will be saved as tar files:
.../sample-inputs.tar.gz
.../sample-outputs.tar.gz
.../sample-labels.tar.gz
:param input_samples: The input samples to save.
:param output_samples: The output samples to save.
:param label_samples: The label samples to save.
:param target_path: The path to save the samples to.
:param as_tar: Whether to save the samples as tar files.
"""

for samples, names in zip(
[input_samples, output_samples, label_samples],
[InputsNames, OutputsNames, LabelNames],
):
if samples is not None:
_LOGGER.info(f"Exporting {names.basename.value} to {target_path}")
export_data_sample(samples, names, target_path, as_tar)


def export_data_sample(
samples, names: Enum, target_path: Union[Path, str], as_tar: bool = False
):
from sparseml.pytorch.utils.helpers import tensors_export, tensors_to_device

samples = tensors_to_device(samples, "cpu")

tensors_export(
tensors=samples,
export_dir=os.path.join(target_path, names.basename.value),
name_prefix=names.filename.value,
)
if as_tar:
folder_path = os.path.join(target_path, names.basename.value)
with tarfile.open(folder_path + ".tar.gz", "w:gz") as tar:
tar.add(folder_path, arcname=os.path.basename(folder_path))
shutil.rmtree(folder_path)
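A short sketch of wiring the two new helpers together outside of `export()`, using a toy `DataLoader` and an identity model as stand-ins; the import path mirrors the new file location:

```python
# Sketch: sample a few batches, then write sample-inputs/-outputs/-labels tarballs.
# The dataset, loader, and model here are toy placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset

from sparseml.export.export_data import create_data_samples, export_data_samples

dataset = TensorDataset(
    torch.randn(8, 3, 224, 224),   # images
    torch.randint(0, 10, (8,)),    # labels
)
loader = DataLoader(dataset, batch_size=1)
model = torch.nn.Identity()        # stand-in for a (sparse) PyTorch model

inputs, outputs, labels = create_data_samples(
    data_loader=loader, model=model, num_samples=4
)
export_data_samples(
    target_path="./exported_model",
    input_samples=inputs,
    output_samples=outputs,
    label_samples=labels,
    as_tar=True,
)
```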
1 change: 0 additions & 1 deletion src/sparseml/export/helpers.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from enum import Enum
from pathlib import Path
26 changes: 21 additions & 5 deletions src/sparseml/integration_helper_functions.py
@@ -14,11 +14,11 @@

from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field

from sparseml.export.export_torch_model import export_model
from sparseml.export.export_data import create_data_samples as create_data_samples_
from sparsezoo.utils.registry import RegistryMixin


@@ -75,10 +75,26 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
graph_optimizations: Optional[Dict[str, Callable]] = Field(
description="A mapping from names to graph optimization functions "
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."

create_data_samples: Callable[
[
Optional["torch.nn.Module"], "torch.utils.data.DataLoader", int  # noqa F821
],
Tuple[
List["torch.Tensor"], # noqa F821
Optional[List["torch.Tensor"]], # noqa F821
List["torch.Tensor"], # noqa F821
],
] = Field(
default=create_data_samples_,
description="A function that takes: "
" - an optional (sparse) PyTorch model "
" - a data loader "
" - the number of samples to generate "
"and returns: "
" - the inputs, (optionally) outputs, and labels as torch tensors ",
)

create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
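For illustration, a hedged sketch of how an integration could swap in its own sampler by overriding the new `create_data_samples` field. The `register(name=...)` decorator is assumed from sparsezoo's `RegistryMixin`, and the variant shown (wrapping the forward pass in `torch.no_grad()`) is purely hypothetical:

```python
# Hypothetical integration override of the create_data_samples hook.
# Assumes RegistryMixin exposes a register(name=...) decorator.
from typing import Callable

import torch
from pydantic import Field

from sparseml.integration_helper_functions import IntegrationHelperFunctions


def create_data_samples_no_grad(data_loader, model=None, num_samples=1):
    # same contract as the default helper, but runs the model under no_grad
    inputs, outputs, labels = [], [], []
    for batch_num, (inputs_, labels_) in enumerate(data_loader):
        if batch_num == num_samples:
            break
        if model is not None:
            with torch.no_grad():
                outputs.append(model(inputs_))
        inputs.append(inputs_)
        labels.append(labels_)
    return inputs, outputs, labels


@IntegrationHelperFunctions.register(name="my-integration")
class MyIntegrationHelperFunctions(IntegrationHelperFunctions):
    create_data_samples: Callable = Field(default=create_data_samples_no_grad)
```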
121 changes: 121 additions & 0 deletions tests/sparseml/export/test_export_data.py
@@ -0,0 +1,121 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tarfile
from enum import Enum

import pytest

from sparseml.export.export_data import create_data_samples, export_data_sample


@pytest.fixture()
def dummy_names():
class LabelNames(Enum):
basename = "sample-dummies"
filename = "dummy"

return LabelNames


@pytest.fixture()
def dummy_samples():
import torch

num_samples = 5
batch_size = 3
samples = [torch.randn(batch_size, 3, 224, 224) for _ in range(num_samples)]

return samples


@pytest.mark.parametrize(
"as_tar",
[True, False],
)
def test_export_data_sample(tmp_path, as_tar, dummy_names, dummy_samples):
export_data_sample(
samples=dummy_samples, names=dummy_names, target_path=tmp_path, as_tar=as_tar
)

dir_name = dummy_names.basename.value
dir_name_tar = dummy_names.basename.value + ".tar.gz"

if as_tar:
with tarfile.open(os.path.join(tmp_path, dir_name_tar)) as tar:
tar.extractall(path=tmp_path)

assert set(os.listdir(tmp_path)) == (
{dir_name} if not as_tar else {dir_name, dir_name_tar}
)
assert set(os.listdir(os.path.join(tmp_path, "sample-dummies"))) == {
f"dummy-000{i}.npz" for i in range(len(dummy_samples))
}


@pytest.mark.parametrize(
"model",
[True, False],
)
@pytest.mark.parametrize("num_samples", [0, 1, 5])
def test_create_data_samples(num_samples, model):
pytest.importorskip("torch", reason="test requires pytorch")

import torch
from torch.utils.data import DataLoader, Dataset

model = torch.nn.Sequential(torch.nn.Identity()) if model else None

class DummyDataset(Dataset):
def __init__(self, inputs, outputs):
self.data = inputs
self.target = outputs

def __len__(self):
return len(self.data)

def __getitem__(self, index):
data_sample = self.data[index]
target_sample = self.target[index]

return data_sample, target_sample

inputs = torch.randn((100, 3, 224, 224))
labels = torch.randint(
0,
10,
(
100,
50,
),
)

custom_dataset = DummyDataset(inputs, labels)

data_loader = DataLoader(custom_dataset, batch_size=1)

inputs, outputs, labels = create_data_samples(
data_loader=data_loader, num_samples=num_samples, model=model
)

assert all(tuple(input.shape) == (1, 3, 224, 224) for input in inputs)
assert all(tuple(label.shape) == (1, 50) for label in labels)
assert len(inputs) == num_samples == len(labels)

if model is not None:
assert len(outputs) == num_samples
assert all(tuple(output.shape) == (1, 3, 224, 224) for output in outputs)
@@ -28,3 +28,4 @@ def test_integration_helper_functions():
assert image_classification.create_dummy_input
assert image_classification.export
assert image_classification.graph_optimizations is None
assert image_classification.create_data_samples