Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Dec 8, 2023
1 parent 8b2fca0 commit ff52598
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 13 deletions.
27 changes: 21 additions & 6 deletions src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pathlib import Path
from typing import Any, List, Optional, Union

from sparseml.export.helpers import apply_optimizations
from sparseml.export.helpers import apply_optimizations, export_sample_inputs_outputs
from sparseml.exporters import ExportTargets
from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Expand All @@ -42,7 +42,7 @@ def export(
single_graph_file: bool = True,
graph_optimizations: Union[str, List[str], None] = "all",
validate_correctness: bool = False,
export_sample_inputs_outputs: bool = False,
num_export_samples: int = 0,
deployment_directory_name: str = "deployment",
device: str = "auto",
):
Expand Down Expand Up @@ -81,8 +81,8 @@ def export(
to the exported model. Defaults to 'all'.
:param validate_correctness: Whether to validate the correctness
of the exported model. Defaults to False.
:param export_sample_inputs_outputs: Whether to export sample
inputs and outputs for the exported model.Defaults to False.
:param num_export_samples: The number of samples to export for
the exported model. Defaults to 0.
:param deployment_directory_name: The name of the deployment
directory to create for the exported model. Thus, the exported
model will be saved to `target_path/deployment_directory_name`.
Expand Down Expand Up @@ -123,8 +123,23 @@ def export(
single_graph_file=single_graph_file,
)

if export_sample_inputs_outputs:
helper_functions.export_sample_inputs_outputs(model, target_path)
if num_export_samples:
data_loader = auxiliary_items.get("validation_loader")
if data_loader is None:
raise ValueError(
"To export sample inputs/outputs a data loader is needed."
"To enable the export, provide a `validatation_loader` "
"as a part of `auxiliary_items` output of the `create_model` function."
)
input_samples, output_samples = helper_functions.create_sample_inputs_outputs(
num_samples=num_export_samples, data_loader=data_loader
)
export_sample_inputs_outputs(
input_samples=input_samples,
output_samples=output_samples,
target_path=target_path,
as_tar=True,
)

deployment_path = helper_functions.create_deployment_folder(
source_path, target_path, deployment_directory_name
Expand Down
69 changes: 67 additions & 2 deletions src/sparseml/export/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tarfile
from collections import OrderedDict
from enum import Enum
from pathlib import Path
Expand All @@ -22,7 +24,7 @@
from sparsezoo.utils.onnx import save_onnx


__all__ = ["apply_optimizations"]
__all__ = ["apply_optimizations", "export_sample_inputs_outputs"]


class GraphOptimizationOptions(Enum):
Expand All @@ -34,6 +36,69 @@ class GraphOptimizationOptions(Enum):
all = "all"


class OutputsNames(Enum):
    """Directory basename and file prefix used when exporting output samples."""

    # directory the output samples are written into (under target_path)
    basename = "sample-outputs"
    # file prefix for each exported sample (e.g. out-0000.npz)
    filename = "out"


class InputsNames(Enum):
    """Directory basename and file prefix used when exporting input samples."""

    # directory the input samples are written into (under target_path)
    basename = "sample-inputs"
    # file prefix for each exported sample (e.g. inp-0000.npz)
    filename = "inp"


def export_sample_inputs_outputs(
    input_samples: List["torch.Tensor"],  # noqa F821
    output_samples: List["torch.Tensor"],  # noqa F821
    target_path: Union[Path, str],
    as_tar: bool = False,
):
    """
    Save the input and output samples to the target path.

    Input samples will be saved to:
        .../sample-inputs/inp-0000.npz
        .../sample-inputs/inp-0001.npz
        ...
    Output samples will be saved to:
        .../sample-outputs/out-0000.npz
        .../sample-outputs/out-0001.npz
        ...
    If as_tar is True, each sample folder is packed into a tar archive and
    the unpacked folder is removed afterwards, leaving only:
        .../sample-inputs.tar.gz
        .../sample-outputs.tar.gz

    :param input_samples: The input samples to save.
    :param output_samples: The output samples to save.
    :param as_tar: Whether to save the samples as tar files.
    :param target_path: The path to save the samples to.
    """

    # imported locally — presumably so importing this module does not
    # hard-require torch; TODO confirm against the project's import policy
    from sparseml.pytorch.utils.helpers import tensors_export, tensors_to_device

    # samples may live on an accelerator; move to CPU before serialization
    input_samples = tensors_to_device(input_samples, "cpu")
    output_samples = tensors_to_device(output_samples, "cpu")

    for tensors, names in zip(
        [input_samples, output_samples], [InputsNames, OutputsNames]
    ):
        tensors_export(
            tensors=tensors,
            export_dir=os.path.join(target_path, names.basename.value),
            name_prefix=names.filename.value,
        )
    if as_tar:
        # replace each sample folder with a .tar.gz archive of its contents
        for folder_name_to_tar in [
            InputsNames.basename.value,
            OutputsNames.basename.value,
        ]:
            folder_path = os.path.join(target_path, folder_name_to_tar)
            with tarfile.open(folder_path + ".tar.gz", "w:gz") as tar:
                tar.add(folder_path, arcname=os.path.basename(folder_path))
            shutil.rmtree(folder_path)


def apply_optimizations(
onnx_file_path: Union[str, Path],
available_optimizations: OrderedDict[str, Callable],
Expand Down
39 changes: 35 additions & 4 deletions src/sparseml/integration_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field
from tqdm import tqdm

from sparsezoo.utils.registry import RegistryMixin

Expand All @@ -32,6 +33,27 @@ class Integrations(Enum):
image_classification = "image-classification"


def create_sample_inputs_outputs(
    data_loader: "torch.utils.data.DataLoader",  # noqa F821
    num_samples: int = 1,
) -> Tuple[List["torch.Tensor"], List["torch.Tensor"]]:  # noqa F821
    """
    Fetch a batch of samples from the data loader and return the inputs and outputs

    If the data loader is exhausted before `num_samples` batches are drawn,
    fewer samples are returned.

    :param data_loader: The data loader to get a batch of inputs/outputs from.
    :param num_samples: The number of samples to generate. Defaults to 1
    :return: The inputs and outputs as lists of torch tensors
    """
    from itertools import islice

    inputs, outputs = [], []
    # islice stops after exactly `num_samples` batches, so no extra batch is
    # fetched from the loader (an enumerate/break loop would pull one more);
    # `total` lets tqdm render an actual progress fraction
    for data in tqdm(islice(data_loader, num_samples), total=num_samples):
        # assumes each batch is an (input, target) pair (or a longer sequence
        # whose first two items are input/target) — TODO confirm for all loaders
        inputs.append(data[0])
        outputs.append(data[1])

    return inputs, outputs


class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"""
Registry that maps names to helper functions
Expand Down Expand Up @@ -74,10 +96,19 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
graph_optimizations: Optional[Dict[str, Callable]] = Field(
description="A mapping from names to graph optimization functions "
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
"a (sparse) PyTorch model."

create_sample_inputs_outputs: Callable[
Tuple["torch.utils.data.DataLoader", int], # noqa F821
Tuple[List["torch.Tensor"], List["torch.Tensor"]], # noqa F821
] = Field(
default=create_sample_inputs_outputs,
description="A function that takes: "
" - a data loader "
" - the number of samples to generate "
"and returns: "
" - the inputs and outputs as torch tensors ",
)

create_deployment_folder: Optional[Callable] = Field(
description="A function that creates a "
"deployment folder for the exporter ONNX model"
Expand Down
56 changes: 55 additions & 1 deletion tests/sparseml/export/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,71 @@
# limitations under the License.

import logging
import os
import tarfile
from collections import OrderedDict

import onnx
import pytest

from src.sparseml.export.helpers import apply_optimizations
from src.sparseml.export.helpers import (
apply_optimizations,
export_sample_inputs_outputs,
)
from tests.sparseml.exporters.transforms.test_onnx_transform import (
_create_model as create_dummy_onnx_file,
)


@pytest.mark.parametrize(
    "as_tar",
    [True, False],
)
def test_export_sample_inputs_outputs(tmp_path, as_tar):
    # Check that sample inputs/outputs are written to disk (optionally tarred)
    pytest.importorskip("torch", reason="test requires pytorch")
    import torch

    batch_size = 3
    num_samples = 5

    inputs = [torch.randn(batch_size, 3, 224, 224) for _ in range(num_samples)]
    outputs = [torch.randn(batch_size, 1000) for _ in range(num_samples)]

    export_sample_inputs_outputs(
        input_samples=inputs,
        output_samples=outputs,
        target_path=tmp_path,
        as_tar=as_tar,
    )

    expected_dirs = {"sample-inputs", "sample-outputs"}
    expected_tars = {"sample-inputs.tar.gz", "sample-outputs.tar.gz"}

    if as_tar:
        # only the tarballs should remain; unpack them for the file checks below
        assert set(os.listdir(tmp_path)) == expected_tars
        for tar_name in expected_tars:
            with tarfile.open(os.path.join(tmp_path, tar_name)) as tar:
                tar.extractall(path=tmp_path)
        assert set(os.listdir(tmp_path)) == expected_dirs | expected_tars
    else:
        assert set(os.listdir(tmp_path)) == expected_dirs

    # each folder should contain one .npz per sample, zero-padded and 0-based
    for directory, prefix in (("sample-inputs", "inp"), ("sample-outputs", "out")):
        expected_files = {f"{prefix}-{idx:04d}.npz" for idx in range(num_samples)}
        assert set(os.listdir(os.path.join(tmp_path, directory))) == expected_files


def foo(onnx_model):
logging.debug("foo")
return onnx_model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ def test_integration_helper_functions():
assert image_classification.create_dummy_input
assert image_classification.export
assert image_classification.graph_optimizations is None
assert image_classification.create_sample_inputs_outputs
59 changes: 59 additions & 0 deletions tests/sparseml/test_integration_helper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from sparseml.integration_helper_functions import create_sample_inputs_outputs


@pytest.mark.parametrize("num_samples", [0, 1, 5])
def test_create_sample_inputs_outputs(num_samples):
    # Check that the helper returns the requested number of (input, target)
    # batches with the batch dimension preserved
    pytest.importorskip("torch", reason="test requires pytorch")
    import torch
    from torch.utils.data import DataLoader, Dataset

    class DummyDataset(Dataset):
        # minimal map-style dataset yielding (input, target) pairs
        def __init__(self, inputs, outputs):
            self.data = inputs
            self.target = outputs

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index], self.target[index]

    inputs = torch.randn((100, 3, 224, 224))
    outputs = torch.randint(0, 10, (100, 50))

    custom_dataset = DummyDataset(inputs, outputs)

    data_loader = DataLoader(custom_dataset, batch_size=1)

    inputs, outputs = create_sample_inputs_outputs(data_loader, num_samples)

    # batch_size=1 -> every sample keeps a leading batch dimension of 1
    assert all(tuple(input.shape) == (1, 3, 224, 224) for input in inputs)
    assert all(tuple(output.shape) == (1, 50) for output in outputs)

    assert len(inputs) == num_samples == len(outputs)

0 comments on commit ff52598

Please sign in to comment.