-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Export Refactor][Image Classification] `export_sample_inputs_outputs…
…` function (#1888) * initial commit * looking good, time to cleanup * Delete src/sparseml/export/helpers.py * Delete tests/sparseml/export/test_helpers.py * ready for review * improve design * tests pass * reuse _validate_dataset_num_classes * initial commit * Update src/sparseml/pytorch/image_classification/integration_helper_functions.py * Update src/sparseml/pytorch/image_classification/integration_helper_functions.py * ready for review * Update src/sparseml/export/export.py * Update src/sparseml/integration_helper_functions.py * initial commit * fixes * ready for review * nit * add return * initial commit * initial commit * PR comments * beautification
- Loading branch information
1 parent
9096b0d
commit 16c9bf3
Showing
6 changed files
with
309 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
import os | ||
import shutil | ||
import tarfile | ||
from enum import Enum | ||
from pathlib import Path | ||
from typing import List, Optional, Tuple, Union | ||
|
||
from tqdm import tqdm | ||
|
||
|
||
__all__ = ["create_data_samples"] | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class LabelNames(Enum): | ||
basename = "sample-labels" | ||
filename = "lab" | ||
|
||
|
||
class OutputsNames(Enum): | ||
basename = "sample-outputs" | ||
filename = "out" | ||
|
||
|
||
class InputsNames(Enum): | ||
basename = "sample-inputs" | ||
filename = "inp" | ||
|
||
|
||
def create_data_samples( | ||
data_loader: "torch.utils.data.DataLoader", # noqa F821 | ||
model: Optional["torch.nn.Module"] = None, # noqa F821 | ||
num_samples: int = 1, | ||
) -> Tuple[ | ||
List["torch.Tensor"], # noqa F821 | ||
Optional[List["torch.Tensor"]], # noqa F821 | ||
List["torch.Tensor"], # noqa F821 | ||
]: | ||
""" | ||
Fetch a batch of samples from the data loader and return the inputs and outputs | ||
:param data_loader: The data loader to get a batch of inputs/outputs from. | ||
:param num_samples: The number of samples to generate. Defaults to 1 | ||
:return: The inputs and outputs as lists of torch tensors | ||
""" | ||
inputs, outputs, labels = [], [], [] | ||
for batch_num, (inputs_, labels_) in tqdm(enumerate(data_loader)): | ||
if batch_num == num_samples: | ||
break | ||
if model: | ||
outputs_ = model(inputs_) | ||
outputs.append(outputs_) | ||
inputs.append(inputs_) | ||
labels.append(labels_) | ||
|
||
return inputs, outputs, labels | ||
|
||
|
||
def export_data_samples( | ||
target_path: Union[Path, str], | ||
input_samples: Optional[List["torch.Tensor"]] = None, # noqa F821 | ||
output_samples: Optional[List["torch.Tensor"]] = None, # noqa F821 | ||
label_samples: Optional[List["torch.Tensor"]] = None, # noqa F821 | ||
as_tar: bool = False, | ||
): | ||
""" | ||
Save the input, labels and output samples to the target path. | ||
All the input files are optional. If a sample is None, | ||
it will not be saved. | ||
Input samples will be saved to: | ||
.../sample-inputs/inp_0001.npz | ||
.../sample-inputs/inp_0002.npz | ||
... | ||
Output samples will be saved to: | ||
.../sample-outputs/out_0001.npz | ||
.../sample-outputs/out_0002.npz | ||
... | ||
Label samples will be saved to: | ||
.../sample-labels/lab_0001.npz | ||
.../sample-labels/lab_0002.npz | ||
... | ||
If as_tar is True, the samples will be saved as tar files: | ||
.../sample-inputs.tar.gz | ||
.../sample-outputs.tar.gz | ||
.../sample-labels.tar.gz | ||
:param input_samples: The input samples to save. | ||
:param output_samples: The output samples to save. | ||
:param target_path: The path to save the samples to. | ||
:param as_tar: Whether to save the samples as tar files. | ||
""" | ||
|
||
for samples, names in zip( | ||
[input_samples, output_samples, label_samples], | ||
[InputsNames, OutputsNames, LabelNames], | ||
): | ||
if samples is not None: | ||
_LOGGER.info(f"Exporting {names.basename.value} to {target_path}") | ||
export_data_sample(samples, names, target_path, as_tar) | ||
|
||
|
||
def export_data_sample( | ||
samples, names: Enum, target_path: Union[Path, str], as_tar: bool = False | ||
): | ||
from sparseml.pytorch.utils.helpers import tensors_export, tensors_to_device | ||
|
||
samples = tensors_to_device(samples, "cpu") | ||
|
||
tensors_export( | ||
tensors=samples, | ||
export_dir=os.path.join(target_path, names.basename.value), | ||
name_prefix=names.filename.value, | ||
) | ||
if as_tar: | ||
folder_path = os.path.join(target_path, names.basename.value) | ||
with tarfile.open(folder_path + ".tar.gz", "w:gz") as tar: | ||
tar.add(folder_path, arcname=os.path.basename(folder_path)) | ||
shutil.rmtree(folder_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import tarfile | ||
from enum import Enum | ||
|
||
import pytest | ||
|
||
from sparseml.export.export_data import create_data_samples, export_data_sample | ||
|
||
|
||
@pytest.fixture() | ||
def dummy_names(): | ||
class LabelNames(Enum): | ||
basename = "sample-dummies" | ||
filename = "dummy" | ||
|
||
return LabelNames | ||
|
||
|
||
@pytest.fixture() | ||
def dummy_samples(): | ||
import torch | ||
|
||
num_samples = 5 | ||
batch_size = 3 | ||
samples = [torch.randn(batch_size, 3, 224, 224) for _ in range(num_samples)] | ||
|
||
return samples | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"as_tar", | ||
[True, False], | ||
) | ||
def test_export_data_sample(tmp_path, as_tar, dummy_names, dummy_samples): | ||
export_data_sample( | ||
samples=dummy_samples, names=dummy_names, target_path=tmp_path, as_tar=as_tar | ||
) | ||
|
||
dir_name = dummy_names.basename.value | ||
dir_name_tar = dummy_names.basename.value + ".tar.gz" | ||
|
||
if as_tar: | ||
with tarfile.open(os.path.join(tmp_path, dir_name_tar)) as tar: | ||
tar.extractall(path=tmp_path) | ||
|
||
assert ( | ||
set(os.listdir(tmp_path)) == {dir_name} | ||
if not as_tar | ||
else {dir_name, dir_name_tar} | ||
) | ||
assert set(os.listdir(os.path.join(tmp_path, "sample-dummies"))) == { | ||
f"dummy-000{i}.npz" for i in range(len(dummy_samples)) | ||
} | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"model", | ||
[True, False], | ||
) | ||
@pytest.mark.parametrize("num_samples", [0, 1, 5]) | ||
def test_create_data_samples(num_samples, model): | ||
pytest.importorskip("torch", reason="test requires pytorch") | ||
|
||
import torch | ||
from torch.utils.data import DataLoader, Dataset | ||
|
||
model = torch.nn.Sequential(torch.nn.Identity()) if model else None | ||
|
||
class DummyDataset(Dataset): | ||
def __init__(self, inputs, outputs): | ||
self.data = inputs | ||
self.target = outputs | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def __getitem__(self, index): | ||
data_sample = self.data[index] | ||
target_sample = self.target[index] | ||
|
||
return data_sample, target_sample | ||
|
||
inputs = torch.randn((100, 3, 224, 224)) | ||
labels = torch.randint( | ||
0, | ||
10, | ||
( | ||
100, | ||
50, | ||
), | ||
) | ||
|
||
custom_dataset = DummyDataset(inputs, labels) | ||
|
||
data_loader = DataLoader(custom_dataset, batch_size=1) | ||
|
||
inputs, outputs, labels = create_data_samples( | ||
data_loader=data_loader, num_samples=num_samples, model=model | ||
) | ||
|
||
assert all(tuple(input.shape) == (1, 3, 224, 224) for input in inputs) | ||
assert all(tuple(label.shape) == (1, 50) for label in labels) | ||
assert len(inputs) == num_samples == len(labels) | ||
|
||
if model is not None: | ||
assert len(outputs) == num_samples | ||
assert all(tuple(output.shape) == (1, 3, 224, 224) for output in outputs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters