Support to Generate Fake Sample/Inputs (#1180) (#1197)
* Support to Generate Fake Sample/Inputs and outputs if no `--data_args` supplied in export script

* Address all review comments
  - Simplify call to `_get_fake_inputs`
  - Save inputs/outputs to `sample-inputs`/`sample-outputs`
rahul-tuli authored Dec 2, 2022
1 parent 238b6ef commit 68f2a2f
Showing 2 changed files with 76 additions and 15 deletions.
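
As a rough illustration of the new behavior, the sketch below calls the export entry point without real dataset arguments so that synthetic samples are generated. Only `export_transformer_to_onnx`, `num_export_samples`, and `data_args` are taken from the diff; the `task` and `model_path` values, and any other required arguments, are assumptions for illustration only.

    from sparseml.transformers.export import export_transformer_to_onnx

    # With data_args left as None, the export path now logs a notice and falls
    # back to synthetic sample inputs/outputs instead of raising a ValueError.
    export_transformer_to_onnx(
        task="question-answering",       # assumed task name
        model_path="./my-pruned-model",  # hypothetical checkpoint directory
        num_export_samples=20,           # request 20 sample inputs/outputs
        data_args=None,                  # no real dataset arguments supplied
        # ...any additional required export arguments
    )
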
20 changes: 13 additions & 7 deletions src/sparseml/transformers/export.py
@@ -246,9 +246,10 @@ def export_transformer_to_onnx(
     )
 
     if num_export_samples > 0 and data_args is None:
-        raise ValueError(
+        _LOGGER.info(
             f"--data_args is needed for exporting {num_export_samples} "
-            f"samples but got {data_args}"
+            "real samples but got None, synthetic data samples will be "
+            "generated based on model input/output shapes"
         )
     data_args: Dict[str, Any] = _parse_data_args(data_args)
 
@@ -265,7 +266,7 @@ def export_transformer_to_onnx(
     _LOGGER.info(f"loaded model, config, and tokenizer from {model_path}")
 
     eval_dataset = None
-    if num_export_samples > 0:
+    if num_export_samples > 0 and data_args:
         tokenized_dataset = load_task_dataset(
             task=task,
             tokenizer=tokenizer,
@@ -316,12 +317,16 @@ def export_transformer_to_onnx(
     # Rearrange inputs' keys to match those defined by model foward func, which
     # seem to define how the order of inputs is determined in the exported model
     forward_args_spec = inspect.getfullargspec(model.__class__.forward)
-    dropped = [f for f in inputs.keys() if f not in forward_args_spec.args]
+    dropped = [
+        input_key
+        for input_key in inputs.keys()
+        if input_key not in forward_args_spec.args
+    ]
     inputs = collections.OrderedDict(
         [
-            (f, inputs[f][0].reshape(1, -1))
-            for f in forward_args_spec.args
-            if f in inputs
+            (func_input_arg_name, inputs[func_input_arg_name][0].reshape(1, -1))
+            for func_input_arg_name in forward_args_spec.args
+            if func_input_arg_name in inputs
         ]
     )
     if dropped:
@@ -362,6 +367,7 @@ def export_transformer_to_onnx(
         _LOGGER.info(f"Exporting {num_export_samples} sample inputs/outputs")
         trainer.save_sample_inputs_outputs(
             num_samples_to_export=num_export_samples,
+            tokenizer=tokenizer,
         )
 
         _LOGGER.info(f"{num_export_samples} sample inputs/outputs exported")
71 changes: 63 additions & 8 deletions src/sparseml/transformers/sparsification/trainer.py
@@ -16,13 +16,15 @@
 SparseML transformers trainer classes and interfaces to be plugged in with
 existing or similiar HF trainer flows
 """
+import collections
 import inspect
 import logging
 import math
 import os
 import warnings
+from contextlib import suppress
 from dataclasses import asdict
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Final, List, Optional, Tuple, Union
 
 import datasets
 import numpy
@@ -36,6 +38,7 @@
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_pt_utils import reissue_pt_warnings
 from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint
+from transformers.utils import PaddingStrategy
 
 from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
 from sparseml.pytorch.utils import (
@@ -56,7 +59,6 @@
     "TransformersTrainer",
 ]
 
-
 _LOGGER = logging.getLogger(__name__)
 TRAINER_STATE_NAME = "trainer_state.json"
 OPTIMIZER_NAME = "optimizer.pt"
@@ -489,31 +491,56 @@ def log_model_sparsification(self):
         )
 
     def save_sample_inputs_outputs(
-        self, num_samples_to_export: int = 100, output_dir: Optional[str] = None
+        self,
+        num_samples_to_export: int = 100,
+        output_dir: Optional[str] = None,
+        tokenizer: Optional[Any] = None,
     ):
         """
         Save sample inputs/outputs/labels in save_dir as .npz arrays
         :param num_samples_to_export: Number of samples to export.
             Defaults to 100
         :param output_dir: The directory to store sample inputs and outputs in
+        :param tokenizer: if eval and train dataset cannot be generated, then
+            the tokenizer is used to generate fake inputs
         """
         num_samples = 0
-        output_dir = output_dir or self.args.output_dir or ""
 
-        sample_in_dir = os.path.join(output_dir, "sample_inputs")
-        sample_out_dir = os.path.join(output_dir, "sample_outputs")
+        if output_dir is None:
+            output_dir = (
+                self.args.output_dir if hasattr(self.args, "output_dir") else ""
+            )
+
+        sample_in_dir = os.path.join(output_dir, "sample-inputs")
+        sample_out_dir = os.path.join(output_dir, "sample-outputs")
 
         os.makedirs(sample_in_dir, exist_ok=True)
         os.makedirs(sample_out_dir, exist_ok=True)
         device = self.model.device
 
+        dataloader = None
         try:
             dataloader = self.get_eval_dataloader()
         except Exception:
-            dataloader = self.get_train_dataloader()
+            with suppress(ValueError):
+                dataloader = self.get_train_dataloader()
 
+        if not dataloader and not tokenizer:
+            raise ValueError(
+                "tokenizer is needed to generate fake sample inputs when Trainer is "
+                "not initialized with a train or eval dataset"
+            )
+        if dataloader is None:
+            # we have the tokenizer so use it
+            dataloader = self._get_fake_dataloader(
+                num_samples=num_samples_to_export, tokenizer=tokenizer
+            )
+
-        _LOGGER.info(f"Exporting {num_samples_to_export} samples to {output_dir}")
+        _LOGGER.info(
+            f"Exporting {num_samples_to_export} samples to "
+            f"{os.path.abspath(output_dir)}"
+        )
         for _, sample_batch in enumerate(dataloader):
             sample_batch.pop("labels", None)
             input_names = list(sample_batch.keys())
@@ -725,6 +752,34 @@ def _add_tensorboard_logger_if_available(self):
             TensorBoardLogger(writer=tensorboard_callback.tb_writer)
         )
 
+    def _get_fake_dataloader(
+        self,
+        num_samples: int,
+        tokenizer: "PreTrainedTokenizerBase",  # noqa: F821
+    ):
+
+        # Rearrange inputs' keys to match those defined by model foward func, which
+        # seem to define how the order of inputs is determined in the exported model
+        forward_args_spec = inspect.getfullargspec(self.model.__class__.forward)
+        synthetic_input: Final = self._get_fake_input(
+            forward_func_input_keys=forward_args_spec.args,
+            tokenizer=tokenizer,
+        )
+        return (synthetic_input for _ in range(num_samples))
+
+    def _get_fake_input(self, forward_func_input_keys, tokenizer):
+        inputs = tokenizer(
+            "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
+        ).data  # Dict[Tensor]
+        inputs = collections.OrderedDict(
+            [
+                (input_key, inputs[input_key][0].reshape(1, -1))
+                for input_key in forward_func_input_keys
+                if input_key in inputs
+            ]
+        )
+        return inputs
+
 
 class TrainerInterface(RecipeManagerTrainerInterface):
     """