Skip to content

Commit

Permalink
[Export Refactor][Image Classification] `export_sample_inputs_outputs…
Browse files Browse the repository at this point in the history
…` function (#1888)

* initial commit

* looking good, time to cleanup

* Delete src/sparseml/export/helpers.py

* Delete tests/sparseml/export/test_helpers.py

* ready for review

* improve design

* tests pass

* reuse _validate_dataset_num_classes

* initial commit

* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py

* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py

* ready for review

* Update src/sparseml/export/export.py

* Update src/sparseml/integration_helper_functions.py

* initial commit

* fixes

* ready for review

* nit

* add return

* initial commit

* initial commit

* PR comments

* beautification
  • Loading branch information
dbogunowicz committed Dec 29, 2023
1 parent 9096b0d commit 16c9bf3
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 13 deletions.
35 changes: 28 additions & 7 deletions src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, List, Optional, Union

from sparseml.export.helpers import apply_optimizations
from sparseml.export_data import export_data_samples
from sparseml.exporters import ExportTargets
from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Expand All @@ -42,7 +43,7 @@ def export(
single_graph_file: bool = True,
graph_optimizations: Union[str, List[str], None] = "all",
validate_correctness: bool = False,
export_sample_inputs_outputs: bool = False,
num_export_samples: int = 0,
deployment_directory_name: str = "deployment",
device: str = "auto",
):
Expand Down Expand Up @@ -81,8 +82,8 @@ def export(
to the exported model. Defaults to 'all'.
:param validate_correctness: Whether to validate the correctness
of the exported model. Defaults to False.
:param export_sample_inputs_outputs: Whether to export sample
inputs and outputs for the exported model.Defaults to False.
:param num_export_samples: The number of samples to export for
the exported model. Defaults to 0.
:param deployment_directory_name: The name of the deployment
directory to create for the exported model. Thus, the exported
model will be saved to `target_path/deployment_directory_name`.
Expand Down Expand Up @@ -123,18 +124,38 @@ def export(
single_graph_file=single_graph_file,
)

if export_sample_inputs_outputs:
helper_functions.export_sample_inputs_outputs(model, target_path)
if num_export_samples:
data_loader = auxiliary_items.get("validation_loader")
if data_loader is None:
raise ValueError(
"To export sample inputs/outputs a data loader is needed."
"To enable the export, provide a `validatation_loader` "
"as a part of `auxiliary_items` output of the `create_model` function."
)
(
input_samples,
output_samples,
label_samples,
) = helper_functions.create_data_samples(
num_samples=num_export_samples, data_loader=data_loader, model=model
)
export_data_samples(
input_samples=input_samples,
output_samples=output_samples,
label_samples=label_samples,
target_path=target_path,
as_tar=True,
)

deployment_path = helper_functions.create_deployment_folder(
source_path, target_path, deployment_directory_name
)

if validate_correctness:
if not export_sample_inputs_outputs:
if not num_export_samples:
raise ValueError(
"To validate correctness sample inputs/outputs are needed."
"To enable the validation, set `export_sample_inputs_outputs`"
"To enable the validation, set `num_export_samples`"
"to True"
)
validate_correctness(deployment_path)
Expand Down
138 changes: 138 additions & 0 deletions src/sparseml/export/export_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
import tarfile
from enum import Enum
from pathlib import Path
from typing import List, Optional, Tuple, Union

from tqdm import tqdm


__all__ = ["create_data_samples"]

_LOGGER = logging.getLogger(__name__)


class LabelNames(Enum):
    """
    Naming scheme used when label samples are exported.
    """

    # name of the folder (and stem of the optional .tar.gz archive)
    # that the label samples are written to, relative to the target path
    basename = "sample-labels"
    # prefix handed to the tensor exporter for every label sample file
    filename = "lab"


class OutputsNames(Enum):
    """
    Naming scheme used when output samples are exported.
    """

    # name of the folder (and stem of the optional .tar.gz archive)
    # that the output samples are written to, relative to the target path
    basename = "sample-outputs"
    # prefix handed to the tensor exporter for every output sample file
    filename = "out"


class InputsNames(Enum):
    """
    Naming scheme used when input samples are exported.
    """

    # name of the folder (and stem of the optional .tar.gz archive)
    # that the input samples are written to, relative to the target path
    basename = "sample-inputs"
    # prefix handed to the tensor exporter for every input sample file
    filename = "inp"


def create_data_samples(
    data_loader: "torch.utils.data.DataLoader",  # noqa F821
    model: Optional["torch.nn.Module"] = None,  # noqa F821
    num_samples: int = 1,
) -> Tuple[
    List["torch.Tensor"],  # noqa F821
    Optional[List["torch.Tensor"]],  # noqa F821
    List["torch.Tensor"],  # noqa F821
]:
    """
    Fetch up to `num_samples` batches from the data loader and, if a model
    is provided, run every fetched batch through it.

    :param data_loader: The data loader to get the batches from. Every batch
        must unpack into an `(inputs, labels)` pair.
    :param model: Optional model that is called on the inputs of every batch
        to produce the output samples. If None, no outputs are generated.
    :param num_samples: The number of batches to fetch. Defaults to 1.
        Fewer batches are returned if the data loader is exhausted earlier.
        A negative value fetches every batch of the data loader.
    :return: A tuple of three lists: the inputs, the outputs
        (empty if no model was provided) and the labels
    """
    from itertools import islice

    # A negative count never stopped the iteration before; keep that
    # behaviour explicit (islice itself rejects negative stop values).
    limit = num_samples if num_samples >= 0 else None

    inputs, outputs, labels = [], [], []
    # islice stops *before* pulling batch number `num_samples`, so no
    # batch is loaded from the data loader only to be thrown away.
    for inputs_, labels_ in tqdm(islice(data_loader, limit)):
        # Explicit None check: container modules such as an empty
        # torch.nn.Sequential define __len__ and would evaluate as False.
        if model is not None:
            outputs.append(model(inputs_))
        inputs.append(inputs_)
        labels.append(labels_)

    return inputs, outputs, labels


def export_data_samples(
    target_path: Union[Path, str],
    input_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
    output_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
    label_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
    as_tar: bool = False,
):
    """
    Write the input, output and label samples below the target path.
    Every group of samples is optional; a group that is None is skipped.

    Each group is exported to its own folder, using its own file prefix:
        .../sample-inputs/inp-0000.npz
        .../sample-inputs/inp-0001.npz
        ...
        .../sample-outputs/out-0000.npz
        .../sample-outputs/out-0001.npz
        ...
        .../sample-labels/lab-0000.npz
        .../sample-labels/lab-0001.npz
        ...

    If as_tar is True, every folder is replaced by a tar archive:
        .../sample-inputs.tar.gz
        .../sample-outputs.tar.gz
        .../sample-labels.tar.gz

    :param target_path: The path to save the samples to.
    :param input_samples: The input samples to save.
    :param output_samples: The output samples to save.
    :param label_samples: The label samples to save.
    :param as_tar: Whether to save the samples as tar files.
    """
    # pair every naming scheme with the samples it belongs to;
    # the order (inputs, outputs, labels) is the export order
    sample_groups = (
        (InputsNames, input_samples),
        (OutputsNames, output_samples),
        (LabelNames, label_samples),
    )

    for names, samples in sample_groups:
        if samples is None:
            continue
        _LOGGER.info(f"Exporting {names.basename.value} to {target_path}")
        export_data_sample(samples, names, target_path, as_tar)


def export_data_sample(
    samples, names: Enum, target_path: Union[Path, str], as_tar: bool = False
):
    """
    Export a single group of samples below the target path.

    The samples are moved to the cpu and handed to `tensors_export`, which
    writes them to the folder `<target_path>/<names.basename.value>` using
    `names.filename.value` as the file name prefix.

    :param samples: The samples to export.
    :param names: Enum providing the `basename` (folder name) and
        `filename` (file name prefix) members for this group of samples.
    :param target_path: The path to create the samples folder in.
    :param as_tar: If True, the samples folder is packed into
        `<folder>.tar.gz` and the folder itself is removed afterwards.
    """
    from sparseml.pytorch.utils.helpers import tensors_export, tensors_to_device

    cpu_samples = tensors_to_device(samples, "cpu")
    export_dir = os.path.join(target_path, names.basename.value)

    tensors_export(
        tensors=cpu_samples,
        export_dir=export_dir,
        name_prefix=names.filename.value,
    )

    if not as_tar:
        return

    # archive the folder under its own name, then drop the unpacked copy
    with tarfile.open(export_dir + ".tar.gz", "w:gz") as archive:
        archive.add(export_dir, arcname=os.path.basename(export_dir))
    shutil.rmtree(export_dir)
1 change: 0 additions & 1 deletion src/sparseml/export/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from enum import Enum
from pathlib import Path
Expand Down
26 changes: 21 additions & 5 deletions src/sparseml/integration_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field

from sparseml.export.export_torch_model import export_model
from sparseml.export.export_data import create_data_samples as create_data_samples_
from sparsezoo.utils.registry import RegistryMixin


Expand Down Expand Up @@ -75,10 +75,26 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
graph_optimizations: Optional[Dict[str, Callable]] = Field(
description="A mapping from names to graph optimization functions "
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."

create_data_samples: Callable[
Tuple[
Optional["torch.nn.Module"], "torch.utils.data.DataLoader", int # noqa F821
],
Tuple[
List["torch.Tensor"], # noqa F821
Optional[List["torch.Tensor"]], # noqa F821
List["torch.Tensor"], # noqa F821
],
] = Field(
default=create_data_samples_,
description="A function that takes: "
" - an optional (sparse) PyTorch model "
" - a data loader "
" - the number of samples to generate "
"and returns: "
" - the inputs, labels and (optionally) outputs as torch tensors ",
)

create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
Expand Down
121 changes: 121 additions & 0 deletions tests/sparseml/export/test_export_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tarfile
from enum import Enum

import pytest

from sparseml.export.export_data import create_data_samples, export_data_sample


@pytest.fixture()
def dummy_names():
    """
    Provide a stand-in names enum with the `basename` and `filename`
    members that `export_data_sample` reads.
    """

    class LabelNames(Enum):
        # folder (and archive stem) the dummy samples are exported to
        basename = "sample-dummies"
        # file name prefix of the exported dummy samples
        filename = "dummy"

    return LabelNames


@pytest.fixture()
def dummy_samples():
    """
    Provide a list of five random float tensors,
    each of shape (3, 3, 224, 224).
    """
    import torch

    sample_count, batch_size = 5, 3
    sample_shape = (batch_size, 3, 224, 224)

    return [torch.randn(*sample_shape) for _ in range(sample_count)]


@pytest.mark.parametrize(
    "as_tar",
    [True, False],
)
def test_export_data_sample(tmp_path, as_tar, dummy_names, dummy_samples):
    """
    Check that `export_data_sample` writes one file per sample into the
    expected folder and, if requested, packs that folder into a tar archive.
    """
    export_data_sample(
        samples=dummy_samples, names=dummy_names, target_path=tmp_path, as_tar=as_tar
    )

    dir_name = dummy_names.basename.value
    dir_name_tar = dir_name + ".tar.gz"
    file_prefix = dummy_names.filename.value

    if as_tar:
        # unpack the archive so that its content can be inspected below
        with tarfile.open(os.path.join(tmp_path, dir_name_tar)) as tar:
            tar.extractall(path=tmp_path)

    # Build the expected set first: a conditional expression binds looser
    # than `==`, so inlining it into the assert would make the tar case
    # assert a non-empty set literal, which is always truthy.
    expected_entries = {dir_name, dir_name_tar} if as_tar else {dir_name}

    assert set(os.listdir(tmp_path)) == expected_entries
    assert set(os.listdir(os.path.join(tmp_path, dir_name))) == {
        f"{file_prefix}-000{i}.npz" for i in range(len(dummy_samples))
    }


@pytest.mark.parametrize(
    "model",
    [True, False],
)
@pytest.mark.parametrize("num_samples", [0, 1, 5])
def test_create_data_samples(num_samples, model):
    """
    Check that `create_data_samples` returns exactly `num_samples` batches
    of inputs and labels, and outputs only if a model is provided.
    """
    pytest.importorskip("torch", reason="test requires pytorch")

    import torch
    from torch.utils.data import DataLoader, Dataset

    # the parametrized flag decides whether a (pass-through) model is used
    model = torch.nn.Sequential(torch.nn.Identity()) if model else None

    class DummyDataset(Dataset):
        def __init__(self, inputs, outputs):
            self.data = inputs
            self.target = outputs

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            data_sample = self.data[index]
            target_sample = self.target[index]

            return data_sample, target_sample

    dataset_inputs = torch.randn((100, 3, 224, 224))
    dataset_labels = torch.randint(
        0,
        10,
        (
            100,
            50,
        ),
    )

    custom_dataset = DummyDataset(dataset_inputs, dataset_labels)

    data_loader = DataLoader(custom_dataset, batch_size=1)

    inputs, outputs, labels = create_data_samples(
        data_loader=data_loader, num_samples=num_samples, model=model
    )

    # check the counts first: the shape checks below are vacuously true
    # for empty lists
    assert len(inputs) == num_samples == len(labels)
    assert all(tuple(input_.shape) == (1, 3, 224, 224) for input_ in inputs)
    assert all(tuple(label.shape) == (1, 50) for label in labels)

    if model is not None:
        assert len(outputs) == num_samples
        # the model is an identity, so outputs keep the input shape
        assert all(tuple(output.shape) == (1, 3, 224, 224) for output in outputs)
    else:
        # without a model no outputs can be produced
        assert not outputs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ def test_integration_helper_functions():
assert image_classification.create_dummy_input
assert image_classification.export
assert image_classification.graph_optimizations is None
assert image_classification.create_data_samples

0 comments on commit 16c9bf3

Please sign in to comment.