Skip to content

Commit

Permalink
Adding --one-shot argument to torchvision export (#1300)
Browse files Browse the repository at this point in the history
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
  • Loading branch information
2 people authored and Benjamin committed Feb 2, 2023
1 parent ba4184b commit f1ace45
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/sparseml/pytorch/torchvision/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import click
from sparseml.pytorch.models.registry import ModelRegistry
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.optim.manager import ScheduledModifierManager
from sparseml.pytorch.torchvision import presets
from sparseml.pytorch.utils import ModuleExporter
from sparseml.pytorch.utils.model import load_model
Expand Down Expand Up @@ -60,6 +61,12 @@
help="The root dir path where the dataset is stored or should "
"be downloaded to if available",
)
@click.option(
"--one-shot",
default=None,
type=str,
help="Path to recipe to use to apply in a one-shot manner",
)
@click.option(
"--labels-to-class-mapping",
type=click.Path(dir_okay=False, file_okay=True, exists=True, path_type=Path),
Expand Down Expand Up @@ -118,6 +125,7 @@ def main(
arch_key: str,
checkpoint_path: str,
dataset_path: Path,
one_shot: Optional[str],
labels_to_class_mapping: Optional[Path],
num_samples: int,
onnx_opset: int,
Expand Down Expand Up @@ -159,6 +167,9 @@ def main(

load_model(checkpoint_path, model, strict=True)

if one_shot is not None:
ScheduledModifierManager.from_yaml(one_shot).apply(model)

if labels_to_class_mapping is not None:
with open(labels_to_class_mapping) as fp:
labels_to_class_mapping = json.load(fp)
Expand Down

0 comments on commit f1ace45

Please sign in to comment.