[Export][Transformers] Implementation of correctness validation (#1935)
* fix tests with help from sara

* Update src/sparseml/transformers/utils/initializers.py

* swap sparsezoo validator for custom one (top k match)

* add more informative error message

* add correctness validation for LLMs

* remove past_key_values from outputs

* remove past_key_values from outputs (2)

* small note comment for the future
dbogunowicz authored Jan 5, 2024
1 parent adfcd8f commit e0c1068
Showing 11 changed files with 262 additions and 146 deletions.
25 changes: 11 additions & 14 deletions src/sparseml/export/export.py
@@ -45,7 +45,6 @@ def export(
opset: int = TORCH_DEFAULT_ONNX_OPSET,
single_graph_file: bool = True,
num_export_samples: int = 0,
batch_size: int = 1,
recipe: Optional[Union[Path, str]] = None,
deployment_directory_name: str = "deployment",
device: str = "cpu",
@@ -90,8 +89,6 @@ def export(
file. Defaults to True.
:param num_export_samples: The number of samples to create for
the exported model. Defaults to 0.
:param batch_size: The batch size to use for exporting the data.
Defaults to None.
:param deployment_directory_name: The name of the deployment
directory to create for the exported model. Thus, the exported
model will be saved to `target_path/deployment_directory_name`.
@@ -157,7 +154,6 @@ def export(
source_path,
device=device,
task=task,
batch_size=batch_size,
recipe=recipe,
**kwargs,
)
@@ -223,16 +219,6 @@ def export(
onnx_model_name=onnx_model_name,
)

_LOGGER.info(
f"Applying optimizations: {graph_optimizations} to the exported model..."
)
if helper_functions.apply_optimizations is not None:
helper_functions.apply_optimizations(
exported_file_path=os.path.join(deployment_path, onnx_model_name),
optimizations=graph_optimizations,
single_graph_file=single_graph_file,
)

if validate_structure:
_LOGGER.info("Validating model structure...")
validate_structure_(
@@ -253,6 +239,17 @@
)
validate_correctness_(target_path, deployment_path, onnx_model_name)

_LOGGER.info(
f"Applying optimizations: {graph_optimizations} to the exported model..."
)

if helper_functions.apply_optimizations is not None:
helper_functions.apply_optimizations(
exported_file_path=os.path.join(deployment_path, onnx_model_name),
optimizations=graph_optimizations,
single_graph_file=single_graph_file,
)

_LOGGER.info(
f"Successfully exported model from:\n{target_path}"
f"\nto\n{deployment_path}\nfor integration: {integration}"
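For orientation, here is a minimal sketch of how the re-ordered export flow might be driven (illustrative only, not part of the commit); the keyword names source_path, target_path, validate_structure, and validate_correctness are assumptions inferred from this diff rather than a confirmed signature.

from sparseml.export.export import export  # module path taken from the diff header above

export(
    source_path="./trained_model",       # assumed: directory containing the source PyTorch model
    target_path="./exported_model",      # assumed: directory that will receive the deployment folder
    task="image-classification",
    num_export_samples=8,                # export a handful of samples so correctness can be checked
    validate_structure=True,             # assumed flag, mirrors the `if validate_structure:` branch above
    validate_correctness=True,           # assumed flag, gates the validate_correctness_(...) call
)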
59 changes: 27 additions & 32 deletions src/sparseml/export/export_data.py
@@ -16,6 +16,7 @@
import os
import shutil
import tarfile
from collections import OrderedDict
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -166,6 +167,17 @@ def create_data_samples(
else labels_
)

# turn all the returned lists into a list of dicts
# to facilitate the sample export
if inputs and not isinstance(inputs[0], dict):
inputs = [dict(input=input) for input in inputs]

if labels and not isinstance(labels[0], dict):
labels = [dict(label=label) for label in labels]

if outputs and not isinstance(outputs[0], dict):
outputs = [dict(output=output) for output in outputs]

return inputs, outputs, labels
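As a small illustration of the conversion above (not part of the commit), bare tensors are wrapped into single-entry dicts so each sample can be exported under a stable name:

import torch

samples = [torch.randn(3, 224, 224), torch.randn(3, 224, 224)]  # hypothetical bare-tensor inputs

# same wrapping as above: a list of tensors becomes a list of {"input": tensor} dicts
if samples and not isinstance(samples[0], dict):
    samples = [dict(input=sample) for sample in samples]

print(list(samples[0].keys()))  # ['input']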


@@ -176,65 +188,48 @@ def run_inference_with_dict_data(
Run inference on a model by inferring the appropriate
inputs from the dictionary input data.
:param data: The data to run inference on
:param model: The model to run inference on (optional)
:return: The inputs, labels and outputs
"""
labels = None
if model is None:
output = None

else:
inputs = {key: value.to(model.device) for key, value in data.items()}
# move the inputs to the model device and
# grab only the first sample from the batch
inputs = {
key: value[0].to(model.device).reshape(1, -1) for key, value in data.items()
}
output_vals = model(**inputs)
if "past_key_values" in output_vals.keys():
output_vals = _unnest_past_key_values(output_vals)
output = {
name: torch.squeeze(val).detach().to("cpu")
for name, val in output_vals.items()
}
inputs = {key: value.to("cpu") for key, value in data.items()}
inputs = {key: value.to("cpu")[0] for key, value in data.items()}
return inputs, labels, output
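A minimal sketch of the batching change above (illustrative only): only the first sample of each batched tensor is kept, and a batch dimension of one is restored before the forward pass.

import torch

data = {
    "input_ids": torch.randint(0, 100, (4, 128)),          # hypothetical batch of 4 sequences
    "attention_mask": torch.ones(4, 128, dtype=torch.long),
}

# same selection as above, using "cpu" as a stand-in for model.device
single_sample = {key: value[0].to("cpu").reshape(1, -1) for key, value in data.items()}
print(single_sample["input_ids"].shape)  # torch.Size([1, 128])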


def _unnest_past_key_values(output_vals: Dict[str, Any]) -> Dict[str, Any]:
"""
Unnest the past key values from the output of the model.
(so they exist on the top level of the output dictionary)
By default the past key values are nested in a list of unnamed
tuples. This function unnests them and names them.
:param output_vals: The output of the model
:return: The output of the model with the past key values unpacked
"""

past_key_values = output_vals["past_key_values"]
output_vals = {
key: value for key, value in output_vals.items() if key != "past_key_values"
}
for i, past_key_values in enumerate(past_key_values):
key, value = past_key_values
output_vals[f"past_key_values_{i}_key"] = key
output_vals[f"past_key_values_{i}_value"] = value
return output_vals
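To illustrate the flattening above (not part of the diff), here is a toy output with two layers of past key values, assuming the private helper is called directly from sparseml.export.export_data (the file being modified):

import torch
from sparseml.export.export_data import _unnest_past_key_values  # helper defined above

output_vals = {
    "logits": torch.randn(1, 8, 32000),
    # two decoder layers, each a (key, value) pair
    "past_key_values": [
        (torch.randn(1, 4, 8, 64), torch.randn(1, 4, 8, 64)),
        (torch.randn(1, 4, 8, 64), torch.randn(1, 4, 8, 64)),
    ],
}
flat = _unnest_past_key_values(output_vals)
print(sorted(flat.keys()))
# ['logits', 'past_key_values_0_key', 'past_key_values_0_value',
#  'past_key_values_1_key', 'past_key_values_1_value']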


# this function is specific to image-classification for now
# to be generalized later
def run_inference_with_tuple_or_list_data(
data: Tuple[Any, Any], model: Optional[torch.nn.Module] = None
) -> Tuple[torch.Tensor, Any, Optional[torch.Tensor]]:
"""
Run inference on a model by inferring the appropriate
inputs from the tuple input data.
:param inputs: The data to run inference on
:param data: The data to run inference on
:param model: The model to run inference on (optional)
:return: The inputs, labels and outputs
"""
# assume that the data is a tuple of (inputs, labels)
inputs, labels = data

outputs = model(inputs) if model else None
if isinstance(outputs, tuple):
# outputs_ contains (logits, softmax)
outputs = outputs[0]
# outputs_ contains (logits, scores)
outputs = OrderedDict(logits=outputs[0], scores=outputs[1])
if len(inputs.size()) == 4:
# if the input is a batch, remove the batch dimension
inputs = torch.squeeze(inputs, 0)
return inputs, labels, outputs
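For illustration (not part of the commit), a toy classifier returning a (logits, scores) tuple exercises the new OrderedDict handling and the batch-dimension squeeze; the import path is taken from the diff header.

import torch
from sparseml.export.export_data import run_inference_with_tuple_or_list_data

class ToyClassifier(torch.nn.Module):
    """Tiny image classifier that returns a (logits, scores) tuple."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3 * 8 * 8, 10)

    def forward(self, x):
        logits = self.fc(x.flatten(start_dim=1))
        return logits, torch.softmax(logits, dim=-1)

images = torch.randn(1, 3, 8, 8)  # a batch holding a single image
labels = torch.tensor([7])
inputs, labels, outputs = run_inference_with_tuple_or_list_data((images, labels), ToyClassifier())
print(inputs.shape)          # torch.Size([3, 8, 8]) -- batch dimension squeezed away
print(list(outputs.keys()))  # ['logits', 'scores']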
107 changes: 73 additions & 34 deletions src/sparseml/export/validators.py
@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import logging
import os.path
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import numpy
import onnxruntime as ort

from sparseml.export.export_data import InputsNames, LabelNames, OutputsNames
from sparseml.export.helpers import ONNX_MODEL_NAME
from sparsezoo.inference import InferenceRunner
from sparsezoo.objects import File, NumpyDirectory
from sparsezoo.utils.numpy import load_numpy


__all__ = ["validate_correctness", "validate_structure"]
@@ -98,47 +102,82 @@ def check_file_presence(file_paths: List[str]) -> List[str]:
return missing_files


# TODO: Need to add a few changes to sparsezoo to support this function
def top_k_match(
ground_truth: numpy.ndarray, prediction: numpy.ndarray, k: int = 2
) -> bool:
"""
Checks if the top k predictions match the ground truth.
:param ground_truth: The ground truth array.
:param prediction: The prediction array.
:param k: The number of top predictions to consider.
:return: True if the top k predictions match the ground truth, False otherwise.
"""
top_k_prediction = numpy.argsort(prediction.flatten())[-k:]
top_k_ground_truth = numpy.argsort(ground_truth.flatten())[-k:]
return numpy.all(top_k_prediction == top_k_ground_truth)
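A quick illustration of the matching rule (not part of the commit): the indices of the k largest entries must agree, and in the same order.

import numpy
from sparseml.export.validators import top_k_match  # defined above in this file

ground_truth = numpy.array([0.1, 0.2, 3.0, 0.5])
close_prediction = numpy.array([0.0, 0.1, 2.5, 0.9])  # same top-2 classes, in the same order
wrong_prediction = numpy.array([0.0, 0.9, 2.5, 0.1])  # second-largest class differs

print(top_k_match(ground_truth, close_prediction, k=2))  # True
print(top_k_match(ground_truth, wrong_prediction, k=2))  # False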


def validate_correctness(
target_path: Union[str, Path], directory: Union[str, Path], onnx_model_name: str
):
target_path: Union[str, Path],
directory: Union[str, Path],
onnx_model_name: str,
validation_function: Callable[..., bool] = top_k_match,
) -> bool:
"""
Validates the correctness of the exported ONNX model by
running it on a set of sample inputs and comparing the
resulting outputs with precomputed ground truth values.
resulting outputs using a validation function.
:param target_path: The directory where the sample inputs and outputs are stored.
:param directory: The directory where the ONNX model is stored.
:param onnx_model_name: The name of the ONNX model.
:param validation_function: The function that will be used to validate the outputs.
:return: True if the validation passes, False otherwise.
"""
# TODO: During testing, add support for the tar.gz scenario (potentially)

sample_inputs_path = os.path.join(target_path, InputsNames.basename.value)
sample_outputs_path = os.path.join(target_path, OutputsNames.basename.value)

sample_inputs = NumpyDirectory(
name=InputsNames.basename.value,
files=[
File(name=file_name, path=os.path.join(sample_inputs_path, file_name))
for file_name in os.listdir(sample_inputs_path)
],
path=sample_inputs_path,
)
sample_outputs = NumpyDirectory(
name=OutputsNames.basename.value,
files=[
File(name=file_name, path=os.path.join(sample_outputs_path, file_name))
for file_name in os.listdir(sample_outputs_path)
],
path=sample_outputs_path,
)
onnx_model = File(
name=onnx_model_name, path=os.path.join(directory, onnx_model_name)
sample_inputs_files = sorted(glob.glob(os.path.join(sample_inputs_path, "*")))
sample_outputs_files = sorted(glob.glob(os.path.join(sample_outputs_path, "*")))

session = ort.InferenceSession(os.path.join(directory, onnx_model_name))

validations = (
[]
) # stores boolean per sample pair (True if validation passes, False otherwise)

for sample_input_file, sample_output_file in zip(
sample_inputs_files, sample_outputs_files
):
sample_input = load_numpy(sample_input_file)
sample_output = load_numpy(sample_output_file)

sample_input_with_batch_dim = OrderedDict(
(key, numpy.expand_dims(value, 0)) for key, value in sample_input.items()
)
outputs = session.run(None, sample_input_with_batch_dim)
if isinstance(outputs, list):
validations_sample = []
for o1, o2 in zip(outputs, sample_output.values()):
validations_sample.append(validation_function(o1, o2))
validations.append(all(validations_sample))
else:
validations.append(validation_function(outputs, sample_output))

if not all(validations):
_LOGGER.error(
f"Correctness validation failed for exported model: {onnx_model_name}. "
"The model outputs match the expected outputs "
f"only for {sum(validations)}/{len(validations)} samples "
f"(according to the validation function: {validation_function.__name__}. "
f"Some failures are expected in the case of quantized models, but not in "
f"the case of non-quantized models. If in doubt, validate the performance "
f"of the exported ONNX model using the NeuralMagic evaluation module."
)
return False

_LOGGER.info(
f"Successfully validated the exported model on all {len(validations)} samples."
)

runner = InferenceRunner(
sample_inputs=sample_inputs,
sample_outputs=sample_outputs,
onnx_file=onnx_model,
)

runner.validate_with_onnx_runtime()
return True
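As a usage sketch (not part of the diff), the validator can be pointed at an export directory, and because validation_function is simply a callable returning a bool, it can be swapped for a stricter numeric check; the directory names and the "model.onnx" file name below are placeholders/assumptions.

import numpy
from sparseml.export.validators import validate_correctness

def close_match(ground_truth: numpy.ndarray, prediction: numpy.ndarray) -> bool:
    # stricter numeric alternative to the default top_k_match
    return bool(numpy.allclose(ground_truth, prediction, atol=1e-3))

passed = validate_correctness(
    target_path="./exported_model",            # placeholder: directory holding sample inputs/outputs
    directory="./exported_model/deployment",   # placeholder: directory holding the ONNX model
    onnx_model_name="model.onnx",              # assumption: default ONNX file name
    validation_function=close_match,
)
print("correctness check passed:", passed)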
@@ -37,7 +37,7 @@

def create_model(
source_path: Union[Path, str],
batch_size: Optional[int] = None,
batch_size: Optional[int] = 1,
device: Optional[str] = None,
**kwargs,
) -> Tuple[torch.nn.Module, Dict[str, Any]]:
2 changes: 1 addition & 1 deletion src/sparseml/pytorch/torch_to_onnx_exporter.py
@@ -233,7 +233,7 @@ def transform(
if _PARSED_TORCH_VERSION < version.parse("1.10.0"):
kwargs["strip_doc_string"] = True
else:
kwargs["training"] = torch.onnx.TrainingMode.PRESERVE
kwargs["training"] = torch.onnx.TrainingMode.EVAL
kwargs["do_constant_folding"] = not module.training
kwargs["keep_initializers_as_inputs"] = False

2 changes: 1 addition & 1 deletion src/sparseml/pytorch/utils/helpers.py
@@ -539,7 +539,7 @@ def _tensors_export_batch(

if isinstance(tensors, Iterable):
# TODO: I am breaking something here? - dbogunowicz
for index, tens in enumerate(zip(tensors)):
for index, tens in enumerate(tensors):
exported_paths.append(
tensor_export(
tens, export_dir, "{}-{:04d}".format(name_prefix, counter + index)
6 changes: 6 additions & 0 deletions src/sparseml/transformers/integration_helper_functions.py
@@ -50,6 +50,7 @@

def create_model(
source_path: Union[Path, str],
dataset_with_labels: bool = False,
device: Optional[str] = None,
task: Optional[str] = None,
recipe: Optional[str] = None,
@@ -60,6 +61,9 @@ def create_model(
loaded_model_kwargs (any relevant objects created along with the model)
:param source_path: The path to the model
:param dataset_with_labels: Whether to allow the dataset to
have "labels" inputs or not. Text-generation datasets may
contain labels (needed for training only)
:param device: The device to use for the model and dataloader instantiation
:param task: The task to use for the model and dataloader instantiation
:param recipe: The recipe to use for the model and dataloader instantiation.
@@ -109,6 +113,8 @@ def create_model(
config=config,
split="validation",
)
if task in TaskNames.text_generation.value and not dataset_with_labels:
validation_dataset = validation_dataset.remove_columns("labels")

trainer = initialize_trainer(model, source_path, validation_dataset)

8 changes: 8 additions & 0 deletions src/sparseml/transformers/utils/sparse_model.py
@@ -258,6 +258,14 @@ def text_generation_from_pretrained(
:param kwargs: keyword arguments to pass through to the AutoModel call
:return: the created model for text generation
"""
# set the config so that exported model is a decoder and does
# not take past_key_values as input
config.is_decoder = True
# whether to use past key values as an input
config.use_past = False
# whether to output past key values
config.use_cache = False

if config.model_type == "opt":
# TODO: Talk to Alex whether this pathway needs to be maintained
def skip(*args, **kwargs):
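For context (illustrative only, not part of the commit), the three flags set above can be reproduced on any Hugging Face config object before loading the model for export; use_past follows this commit's convention rather than a standard transformers attribute, and the model id below is a placeholder.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-1.3b")  # placeholder model id
config.is_decoder = True   # export the model as a standalone decoder
config.use_past = False    # do not take past_key_values as inputs
config.use_cache = False   # do not return past_key_values as outputs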