diff --git a/src/sparseml/export/export.py b/src/sparseml/export/export.py
index 2a1e3014ba9..e9c5e6a600d 100644
--- a/src/sparseml/export/export.py
+++ b/src/sparseml/export/export.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import os
 from pathlib import Path
 from typing import Any, List, Optional, Union
 
@@ -20,13 +21,13 @@
 from sparseml.export.helpers import (
     AVAILABLE_DEPLOYMENT_TARGETS,
     ONNX_MODEL_NAME,
-    apply_optimizations,
     create_deployment_folder,
+    create_export_kwargs,
 )
 from sparseml.export.validators import validate_correctness as validate_correctness_
 from sparseml.export.validators import validate_structure as validate_structure_
 from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
-from sparseml.pytorch.utils.helpers import default_device
+from sparseml.pytorch.utils.helpers import default_device, use_single_gpu
 from src.sparseml.integration_helper_functions import (
     IntegrationHelperFunctions,
     resolve_integration,
@@ -44,6 +45,7 @@ def export(
     opset: int = TORCH_DEFAULT_ONNX_OPSET,
     single_graph_file: bool = True,
     num_export_samples: int = 0,
+    batch_size: int = 1,
     deployment_directory_name: str = "deployment",
     device: str = "auto",
     graph_optimizations: Union[str, List[str], None] = "all",
@@ -51,7 +53,7 @@ def export(
     validate_structure: bool = True,
     integration: Optional[str] = None,
     sample_data: Optional[Any] = None,
-    batch_size: Optional[int] = None,
+    task: Optional[str] = None,
     **kwargs,
 ):
     """
@@ -84,6 +86,8 @@ def export(
         file. Defaults to True.
     :param num_export_samples: The number of samples to create for
         the exported model. Defaults to 0.
+    :param batch_size: The batch size to use for exporting the data.
+        Defaults to 1.
     :param deployment_directory_name: The name of the deployment
         directory to create for the exported model. Thus, the exported
         model will be saved to `target_path/deployment_directory_name`.
@@ -102,7 +106,7 @@ def export(
     :param sample_data: Optional sample data to use for exporting
         the model. If not provided, a dummy input will be created
         for the model. Defaults to None.
-    :param batch_size: The batch size to use for exporting the data.
+    :param task: Optional task to use for exporting the model.
         Defaults to None.
     """
 
@@ -112,6 +116,7 @@ def export(
 
     # choose the appropriate device
     device = default_device() if device == "auto" else device
+    device = use_single_gpu(device) if "cuda" in device else device
 
     # assert the valid deployment target
     if deployment_target not in AVAILABLE_DEPLOYMENT_TARGETS:
@@ -126,33 +131,34 @@ def export(
     _LOGGER.info(f"Starting export for {integration} model...")
 
     helper_functions: IntegrationHelperFunctions = (
-        IntegrationHelperFunctions.load_from_registry(integration)
+        IntegrationHelperFunctions.load_from_registry(integration, task=task)
     )
 
     _LOGGER.info("Creating model for the export...")
-    model, validation_dataloader = helper_functions.create_model(
-        source_path, batch_size, device, **kwargs
+
+    # loaded_model_kwargs may include any objects
+    # that were created along with the model and are needed
+    # for the export
+    model, loaded_model_kwargs = helper_functions.create_model(
+        source_path, device=device, task=task, batch_size=batch_size, **kwargs
     )
 
-    if validation_dataloader:
-        _LOGGER.info("Created validation dataloader for the export")
-    else:
-        _LOGGER.warning(
-            "Failed to create validation dataloader for the export. "
-            "Will be using the dummy (or user-provided) data instead "
-            "and will be not able to export samples or validate the model "
-            "correctness."
+    if loaded_model_kwargs:
+        _LOGGER.info(
+            "Created additional items that will "
+            f"be used for the export: {list(loaded_model_kwargs.keys())}"
         )
 
     sample_data = (
-        helper_functions.create_dummy_input(
-            validation_dataloader=validation_dataloader, **kwargs
-        )
+        helper_functions.create_dummy_input(**loaded_model_kwargs, **kwargs)
         if sample_data is None
         else sample_data
     )
 
     _LOGGER.info(f"Exporting {onnx_model_name} to {target_path}...")
+
+    export_kwargs = create_export_kwargs(loaded_model_kwargs)
+
     onnx_file_path = helper_functions.export(
         model=model,
         sample_data=sample_data,
@@ -160,35 +166,20 @@ def export(
         onnx_model_name=onnx_model_name,
         deployment_target=deployment_target,
         opset=opset,
+        **export_kwargs,
     )
-    _LOGGER.info(f"Successfully exported {onnx_model_name} to {target_path}...")
-
-    _LOGGER.info(
-        f"Applying optimizations: {graph_optimizations} to the exported model..."
-    )
-    apply_optimizations(
-        onnx_file_path=onnx_file_path,
-        target_optimizations=graph_optimizations,
-        available_optimizations=helper_functions.graph_optimizations,
-        single_graph_file=single_graph_file,
-    )
+    _LOGGER.info(f"Successfully exported {onnx_model_name} to {onnx_file_path}...")
 
     if num_export_samples:
         _LOGGER.info(f"Exporting {num_export_samples} samples...")
-        if not validation_dataloader:
-            raise ValueError(
-                "To export sample inputs/outputs a data loader is needed. "
-                "To return a data loader provide the appropriate, integration-specific "
-                "arguments to `create_model` function"
-            )
         (
             input_samples,
             output_samples,
             label_samples,
         ) = helper_functions.create_data_samples(
             num_samples=num_export_samples,
-            data_loader=validation_dataloader,
             model=model,
+            **loaded_model_kwargs,
         )
         export_data_samples(
             input_samples=input_samples,
@@ -207,16 +198,29 @@ def export(
         source_path=source_path,
         target_path=target_path,
         deployment_directory_name=deployment_directory_name,
-        deployment_directory_files=helper_functions.deployment_directory_structure,
+        deployment_directory_files_mandatory=helper_functions.deployment_directory_files_mandatory,  # noqa: E501
+        deployment_directory_files_optional=helper_functions.deployment_directory_files_optional,  # noqa: E501
         onnx_model_name=onnx_model_name,
     )
+
+    _LOGGER.info(
+        f"Applying optimizations: {graph_optimizations} to the exported model..."
+    )
+    if helper_functions.apply_optimizations is not None:
+        helper_functions.apply_optimizations(
+            exported_file_path=os.path.join(deployment_path, onnx_model_name),
+            optimizations=graph_optimizations,
+            single_graph_file=single_graph_file,
+        )
+
     if validate_structure:
         _LOGGER.info("Validating model structure...")
         validate_structure_(
             target_path=target_path,
             deployment_directory_name=deployment_directory_name,
             onnx_model_name=onnx_model_name,
-            deployment_directory_files=helper_functions.deployment_directory_structure,
+            deployment_directory_files_mandatory=helper_functions.deployment_directory_files_mandatory,  # noqa: E501
+            deployment_directory_files_optional=helper_functions.deployment_directory_files_optional,  # noqa: E501
         )
 
     if validate_correctness:
diff --git a/src/sparseml/export/export_data.py b/src/sparseml/export/export_data.py
index 7ecf8f146a6..6c1899bd9f6 100644
--- a/src/sparseml/export/export_data.py
+++ b/src/sparseml/export/export_data.py
@@ -18,7 +18,7 @@
 import tarfile
 from enum import Enum
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from tqdm import tqdm
@@ -46,47 +46,11 @@ class InputsNames(Enum):
     filename = "inp"
 
 
-def create_data_samples(
-    data_loader: torch.utils.data.DataLoader,
-    model: Optional[torch.nn.Module] = None,
-    num_samples: int = 1,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
-    """
-    Fetch a batch of samples from the data loader and return the inputs and outputs
-
-    :param data_loader: The data loader to get a batch of inputs/outputs from.
-    :param model: The model to run the inputs through to get the outputs.
-        If None, the outputs will be an empty list.
-    :param num_samples: The number of samples to generate. Defaults to 1
-    :return: The inputs and outputs as lists of torch tensors
-    """
-    inputs, outputs, labels = [], [], []
-    if model is None:
-        _LOGGER.warning("The model is None. The list of outputs will be empty")
-    for batch_num, (inputs_, labels_) in tqdm(enumerate(data_loader)):
-        if batch_num == num_samples:
-            break
-        if model:
-            outputs_ = model(inputs_)
-            if isinstance(outputs_, tuple):
-                # outputs_ contains (logits, softmax)
-                outputs_ = outputs_[0]
-            outputs.append(outputs_)
-        inputs.append(inputs_)
-        labels.append(
-            torch.IntTensor([labels_])
-            if not isinstance(labels_, torch.Tensor)
-            else labels_
-        )
-
-    return inputs, outputs, labels
-
-
 def export_data_samples(
     target_path: Union[Path, str],
-    input_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
-    output_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
-    label_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
+    input_samples: Optional[List[Any]] = None,
+    output_samples: Optional[List[Any]] = None,
+    label_samples: Optional[List[Any]] = None,
     as_tar: bool = False,
 ):
     """
@@ -116,6 +80,7 @@ def export_data_samples(
 
     :param input_samples: The input samples to save.
     :param output_samples: The output samples to save.
+    :param label_samples: The label samples to save.
     :param target_path: The path to save the samples to.
     :param as_tar: Whether to save the samples as tar files.
""" @@ -124,16 +89,21 @@ def export_data_samples( [input_samples, output_samples, label_samples], [InputsNames, OutputsNames, LabelNames], ): - if samples is not None: + if len(samples) > 0: _LOGGER.info(f"Exporting {names.basename.value} to {target_path}...") - export_data_sample(samples, names, target_path, as_tar) + break_batch = isinstance(samples[0], dict) + export_data_sample(samples, names, target_path, as_tar, break_batch) _LOGGER.info( f"Successfully exported {names.basename.value} to {target_path}!" ) def export_data_sample( - samples, names: Enum, target_path: Union[Path, str], as_tar: bool = False + samples, + names: Enum, + target_path: Union[Path, str], + as_tar: bool = False, + break_batch=False, ): samples = tensors_to_device(samples, "cpu") @@ -142,9 +112,105 @@ def export_data_sample( tensors=samples, export_dir=os.path.join(target_path, names.basename.value), name_prefix=names.filename.value, + break_batch=break_batch, ) if as_tar: folder_path = os.path.join(target_path, names.basename.value) with tarfile.open(folder_path + ".tar.gz", "w:gz") as tar: tar.add(folder_path, arcname=os.path.basename(folder_path)) shutil.rmtree(folder_path) + + +def create_data_samples( + data_loader: torch.utils.data.DataLoader, + model: Optional[torch.nn.Module] = None, + num_samples: int = 1, +) -> Tuple[List[Any], List[Any], List[Any]]: + """ + Fetch a batch of samples from the data loader and return the inputs and outputs + + :param data_loader: The data loader to get a batch of inputs/outputs from. + :param model: The model to run the inputs through to get the outputs. + If None, the outputs will be an empty list. + :param num_samples: The number of samples to generate. Defaults to 1 + :return: The inputs and outputs as lists of torch tensors + """ + inputs, outputs, labels = [], [], [] + if model is None: + _LOGGER.warning("The model is None. The list of outputs will be empty") + + for batch_num, data in tqdm(enumerate(data_loader)): + if batch_num == num_samples: + break + if isinstance(data, dict): + inputs_, labels_, outputs_ = run_inference_with_dict_data( + data=data, model=model + ) + elif isinstance(data, (list, tuple)): + inputs_, labels_, outputs_ = run_inference_with_tuple_or_list_data( + data=data, model=model + ) + else: + raise ValueError( + f"Data type {type(data)} is not supported. " + f"Only dict and tuple are supported" + ) + + inputs.append(inputs_) + if outputs_ is not None: + outputs.append(outputs_) + if labels_ is not None: + labels.append( + torch.IntTensor([labels_]) + if not isinstance(labels_, torch.Tensor) + else labels_ + ) + + return inputs, outputs, labels + + +def run_inference_with_dict_data( + data: Dict[str, Any], model: Optional[torch.nn.Module] = None +) -> Tuple[Dict[str, Any], Any, Optional[Dict[str, Any]]]: + """ + Run inference on a model by inferring the appropriate + inputs from the dictionary input data. 
+
+    :param data: The data to run inference on
+    :param model: The model to run inference on (optional)
+    :return: The inputs, labels and outputs
+    """
+    labels = None
+    if model is None:
+        output = None
+
+    else:
+        inputs = {key: value.to(model.device) for key, value in data.items()}
+        output_vals = model(**inputs)
+        output = {
+            name: torch.squeeze(val).detach().to("cpu")
+            for name, val in output_vals.items()
+        }
+    inputs = {key: value.to("cpu") for key, value in data.items()}
+    return inputs, labels, output
+
+
+def run_inference_with_tuple_or_list_data(
+    data: Tuple[Any, Any], model: Optional[torch.nn.Module] = None
+) -> Tuple[torch.Tensor, Any, Optional[torch.Tensor]]:
+    """
+    Run inference on a model by inferring the appropriate
+    inputs from the tuple input data.
+
+    :param data: The data to run inference on
+    :param model: The model to run inference on (optional)
+    :return: The inputs, labels and outputs
+    """
+    # assume that the data is a (inputs, labels) pair
+    inputs, labels = data
+    outputs = model(inputs) if model else None
+    if isinstance(outputs, tuple):
+        # outputs contains (logits, softmax); keep the logits
+        outputs = outputs[0]
+    return inputs, labels, outputs
diff --git a/src/sparseml/export/helpers.py b/src/sparseml/export/helpers.py
index c06e0139111..404b0f68ad9 100644
--- a/src/sparseml/export/helpers.py
+++ b/src/sparseml/export/helpers.py
@@ -12,17 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 import shutil
 from collections import OrderedDict
 from enum import Enum
 from pathlib import Path
-from typing import Callable, List, Optional, Union
-
-import onnx
+from typing import Any, Callable, Dict, List, Optional, Union
 
 from sparseml.exporters import ExportTargets
-from sparsezoo.utils.onnx import save_onnx
 
 
 __all__ = [
@@ -30,17 +28,50 @@
     "create_deployment_folder",
     "AVAILABLE_DEPLOYMENT_TARGETS",
     "ONNX_MODEL_NAME",
+    "create_export_kwargs",
 ]
 
 AVAILABLE_DEPLOYMENT_TARGETS = [target.value for target in ExportTargets]
 ONNX_MODEL_NAME = "model.onnx"
 ONNX_DATA_NAME = "model.data"
 
+_LOGGER = logging.getLogger(__name__)
+
+
+def create_export_kwargs(
+    loaded_model_kwargs: Dict[str, Any], export_target: str = "deepsparse"
+) -> Dict[str, Any]:
+    """
+    Retrieve the export kwargs from the loaded model kwargs.
+
+    The export kwargs are the kwargs that are passed to the export function.
+    Given the loaded model kwargs and the export_target, one can define which
+    loaded_model_kwargs should be routed to the export kwargs.
+
+    :param loaded_model_kwargs: The loaded model kwargs.
+    :param export_target: The export target.
+    :return: The export kwargs.
+    """
+
+    if export_target not in AVAILABLE_DEPLOYMENT_TARGETS:
+        raise ValueError(
+            f"Export target {export_target} not in "
+            f"available targets {AVAILABLE_DEPLOYMENT_TARGETS}"
+        )
+
+    export_kwargs = {}
+    input_names = loaded_model_kwargs.get("input_names")
+    if input_names is not None:
+        export_kwargs["input_names"] = input_names
+
+    return export_kwargs
+
 
 def create_deployment_folder(
     source_path: Union[Path, str],
     target_path: Union[Path, str],
-    deployment_directory_files: List[str],
+    deployment_directory_files_mandatory: List[str],
+    deployment_directory_files_optional: Optional[List[str]] = None,
     deployment_directory_name: str = "deployment",
     onnx_model_name: Optional[str] = None,
 ) -> str:
@@ -58,9 +89,12 @@ def create_deployment_folder(
     :param target_path: The path to the target folder.
     :param deployment_directory_name: The name of the deployment directory.
         The files will be copied to target_path/deployment_directory_name.
-    :param deployment_directory_files: The list of files to copy to the deployment
-        directory. If the file is an ONNX model (or ONNX data file), the file will
-        be copied from target_path. Else, the file will be copied from source_path.
+    :param deployment_directory_files_mandatory: The mandatory list of files
+        to copy to the deployment directory. If the file is an ONNX model
+        (or ONNX data file), the file will be copied from target_path.
+        Else, the file will be copied from source_path.
+    :param deployment_directory_files_optional: The optional list of files
+        to copy to the deployment directory.
     :param onnx_model_name: The name of the ONNX model file. If not specified,
         defaults to ONNX_MODEL_NAME.
     :return: The path to the deployment folder.
@@ -71,32 +105,65 @@ def create_deployment_folder(
         shutil.rmtree(deployment_folder_dir)
     os.makedirs(deployment_folder_dir, exist_ok=True)
 
-    # copy over the expected files
-    for file_name in deployment_directory_files:
-        if file_name == ONNX_MODEL_NAME:
-            # attempting to move the ONNX model file
-            # (potentially together with the ONNX data file)
-            # from target_path to target_path/deployment_folder_dir
+    deployment_directory_files_optional = deployment_directory_files_optional or []
 
-            # takes into consideration potentially custom ONNX model name
-            onnx_model_name = (
-                ONNX_MODEL_NAME if onnx_model_name is None else onnx_model_name
-            )
+    for file_name in deployment_directory_files_mandatory:
+        move_mandatory_deployment_files(
+            file_name, source_path, target_path, onnx_model_name, deployment_folder_dir
+        )
 
-            _move_onnx_model(
-                onnx_model_name=onnx_model_name,
-                src_path=target_path,
-                target_path=deployment_folder_dir,
-            )
+    for file_name in deployment_directory_files_optional:
+        move_optional_deployment_files(file_name, source_path, deployment_folder_dir)
 
-        else:
-            _copy_file_or_directory(
-                src=os.path.join(source_path, file_name),
-                target=os.path.join(deployment_folder_dir, file_name),
-            )
     return deployment_folder_dir
 
 
+def move_mandatory_deployment_files(
+    file_name: str,
+    source_path: Union[Path, str],
+    target_path: Union[Path, str],
+    onnx_model_name: Optional[str],
+    deployment_folder_dir: Union[Path, str],
+):
+    if file_name == ONNX_MODEL_NAME:
+        # attempting to move the ONNX model file
+        # (potentially together with the ONNX data file)
+        # from target_path to target_path/deployment_folder_dir
+
+        # takes into consideration potentially custom ONNX model name
+        onnx_model_name = (
+            ONNX_MODEL_NAME if onnx_model_name is None else onnx_model_name
+        )
+
+        _move_onnx_model(
+            onnx_model_name=onnx_model_name,
+            src_path=target_path,
+            target_path=deployment_folder_dir,
+        )
+
+    else:
+        _copy_file_or_directory(
+            src=os.path.join(source_path, file_name),
+            target=os.path.join(deployment_folder_dir, file_name),
+        )
+
+
+def move_optional_deployment_files(
+    file_name: str,
+    source_path: Union[Path, str],
+    deployment_folder_dir: Union[Path, str],
+):
+    if os.path.exists(os.path.join(source_path, file_name)):
+        _copy_file_or_directory(
+            src=os.path.join(source_path, file_name),
+            target=os.path.join(deployment_folder_dir, file_name),
+        )
+    else:
+        _LOGGER.warning(
+            f"Optional file {file_name} not found in source path {source_path}"
+        )
+
+
 class GraphOptimizationOptions(Enum):
     """
     Holds the string names of the graph optimization options.
@@ -127,27 +194,28 @@ def apply_optimizations(
     :param single_graph_file: Whether to save the optimized graph to a single
         file or split it into multiple files. By default, it is True.
     """
-    optimizations: List[Callable] = resolve_graph_optimizations(
+    optimizations: Dict[str, Callable] = resolve_graph_optimizations(
         available_optimizations=available_optimizations,
         optimizations=target_optimizations,
     )
 
-    onnx_model = onnx.load(onnx_file_path)
-
-    for optimization in optimizations:
-        onnx_model = optimization(onnx_model)
-
-    if single_graph_file:
-        save_onnx(onnx_model, onnx_file_path)
-        return
+    for name, optimization in optimizations.items():
+        _LOGGER.info(f"Attempting to apply optimization: {name}...")
+        applied = optimization(onnx_file_path)
+        if applied:
+            _LOGGER.info(
+                f"Optimization: {name} has been successfully "
+                f"applied to the ONNX model: {onnx_file_path}"
+            )
 
-    save_onnx_multiple_files(onnx_model)
+    if not single_graph_file:
+        save_onnx_multiple_files(onnx_file_path)
 
 
 def resolve_graph_optimizations(
     optimizations: Union[str, List[str]],
     available_optimizations: Optional[OrderedDict[str, Callable]] = None,
-) -> List[Callable]:
+) -> Dict[str, Callable]:
     """
     Get the optimization functions to apply to the onnx model.
 
@@ -159,21 +227,17 @@ def resolve_graph_optimizations(
         that specifies the set of optimizations to apply.
        If a string, refer to the `GraphOptimizationOptions` enum
        for the available options.
-    return The list of optimization functions to apply.
+    :return: The optimization functions to apply to the onnx model.
     """
     if isinstance(optimizations, str):
         if optimizations == GraphOptimizationOptions.none.value:
-            return []
+            return {}
         elif optimizations == GraphOptimizationOptions.all.value:
-            return (
-                list(available_optimizations.values())
-                if available_optimizations is not None
-                else []
-            )
+            return available_optimizations or {}
        else:
            raise KeyError(f"Unknown graph optimization option: {optimizations}")
    elif isinstance(optimizations, list):
-        return [available_optimizations[optimization] for optimization in optimizations]
+        return {name: available_optimizations[name] for name in optimizations}
    else:
        raise KeyError(f"Unknown graph optimization option: {optimizations}")
diff --git a/src/sparseml/export/validators.py b/src/sparseml/export/validators.py
index 5d07207bf67..9958468e514 100644
--- a/src/sparseml/export/validators.py
+++ b/src/sparseml/export/validators.py
@@ -15,7 +15,7 @@
 import logging
 import os.path
 from pathlib import Path
-from typing import List, Union
+from typing import List, Optional, Union
 
 from sparseml.export.export_data import InputsNames, LabelNames, OutputsNames
 from sparseml.export.helpers import ONNX_MODEL_NAME
@@ -32,7 +32,8 @@ def validate_structure(
     target_path: Union[str, Path],
     deployment_directory_name: str,
     onnx_model_name: str,
-    deployment_directory_files: List[str],
+    deployment_directory_files_mandatory: List[str],
+    deployment_directory_files_optional: Optional[List[str]] = None,
 ):
     """
     Validates the structure of the target_path by
@@ -42,24 +43,35 @@ def validate_structure(
     :param target_path: The directory where the exported files are stored.
     :param deployment_directory_name: The name of the deployment directory.
     :param onnx_model_name: The name of the ONNX model.
-    :param deployment_directory_files: The list of files that should be present
-        in the deployment directory.
+    :param deployment_directory_files_mandatory: The list of files that
+        should be present in the deployment directory.
+    :param deployment_directory_files_optional: The list of files that
+        can be optionally present in the deployment directory.
     """
     sample_files = {InputsNames, OutputsNames, LabelNames}
 
-    deployment_directory_files = [
+    # account for the potentially custom ONNX model name
+    deployment_directory_files_mandatory = [
         onnx_model_name if file_name == ONNX_MODEL_NAME else file_name
-        for file_name in deployment_directory_files
+        for file_name in deployment_directory_files_mandatory
     ]
-
-    mandatory_files = {
+    # obtain full paths
+    deployment_directory_files_mandatory = {
+        os.path.join(target_path, deployment_directory_name, file_name)
+        for file_name in deployment_directory_files_mandatory
+    }
+    deployment_directory_files_optional = {
         os.path.join(target_path, deployment_directory_name, file_name)
-        for file_name in deployment_directory_files
+        for file_name in deployment_directory_files_optional or []
     }
+
+    # obtain full paths for the potential sample files
     optional_files = {
         os.path.join(target_path, name.basename.value) for name in sample_files
     }
-    missing_mandatory_files = check_file_presence(mandatory_files)
+    optional_files.update(deployment_directory_files_optional)
+
+    missing_mandatory_files = check_file_presence(deployment_directory_files_mandatory)
     missing_optional_files = check_file_presence(optional_files)
 
     if missing_optional_files:
diff --git a/src/sparseml/integration_helper_functions.py b/src/sparseml/integration_helper_functions.py
index fee55b87fb6..602a34aa190 100644
--- a/src/sparseml/integration_helper_functions.py
+++ b/src/sparseml/integration_helper_functions.py
@@ -32,10 +32,12 @@ class Integrations(Enum):
     """
 
     image_classification = "image-classification"
+    transformers = "transformers"
 
 
 def resolve_integration(
-    source_path: Union[Path, str], integration: Optional[str] = None
+    source_path: Union[Path, str],
+    integration: Optional[str] = None,
 ) -> str:
     """
     Resolve the integration to use.
@@ -56,15 +58,23 @@ def resolve_integration(
     from sparseml.pytorch.image_classification.utils.helpers import (
         is_image_classification_model,
     )
+    from sparseml.transformers.utils.helpers import is_transformer_model
 
     if (
         integration == Integrations.image_classification.value
         or is_image_classification_model(source_path)
     ):
-        # import to register the image_classification integration helper functions
         import sparseml.pytorch.image_classification.integration_helper_functions  # noqa F401
 
         return Integrations.image_classification.value
+
+    elif integration == Integrations.transformers.value or is_transformer_model(
+        source_path
+    ):
+        import sparseml.transformers.integration_helper_functions  # noqa F401
+
+        return Integrations.transformers.value
     else:
         raise ValueError(
             f"Could not infer integration from source_path:\n{source_path}\n"
@@ -81,36 +91,24 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
     integration.
""" - create_model: Optional[ - Callable[ - Tuple[Union[str, Path], Optional[int], str, Optional[Dict[str, Any]]], - Tuple[ - "torch.nn.Module", # noqa F821 - Optional["torch.utils.data.Dataloader"], # noqa F821 - ], - ] + create_model: Callable[ + [Union[str, Path], ...], + Tuple[ + "torch.nn.Module", # noqa F821 + Optional[Dict[str, Any]], + ], ] = Field( description="A function that takes: " "- a source path to a PyTorch model " - "- a batch size " - "- a device name " - "- (optionally) a dictionary of additional arguments" + "- (optionally) additional arguments" "and returns: " "- a (sparse) PyTorch model " - "- (optionally) a data loader " + "- (optionally) loaded_model_kwargs " + "(any relevant objects created along with the model)" ) - create_dummy_input: Optional[ - Callable[ - Tuple[ - Optional["torch.utils.data.Dataloader"], # noqa F821 - Optional[Dict[str, Any]], - ], - "torch.Tensor", # noqa F821 - ] - ] = Field( + create_dummy_input: Callable[..., "torch.Tensor"] = Field( # noqa F821 description="A function that takes: " - "- (optionally) a data loader " - "- (optionally) a dictionary of additional arguments" + "- appropriate arguments " "and returns: " "- a dummy input for the model (a torch.Tensor) " ) @@ -126,33 +124,38 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel): "and returns the path to the exported model", default=export_model, ) - graph_optimizations: Optional[Dict[str, Callable]] = Field( - description="A mapping from names to graph optimization functions " + apply_optimizations: Optional[Callable[..., None]] = Field( + description="A function that takes:" + " - path to the exported model" + " - names of the optimizations to apply" + " and applies the optimizations to the model", ) create_data_samples: Callable[ - Tuple[ - Optional["torch.nn.Module"], # noqa F821 - "torch.utils.data.DataLoader", # noqa F821 - int, - ], + Tuple[Optional["torch.nn.Module"], int, Optional[Dict[str, Any]]], # noqa F821 Tuple[ List["torch.Tensor"], # noqa F821 Optional[List["torch.Tensor"]], # noqa F821 - List["torch.Tensor"], # noqa F821 + Optional[List["torch.Tensor"]], # noqa F821 ], ] = Field( default=create_data_samples_, description="A function that takes: " " - (optionally) a (sparse) PyTorch model " - " - a data loader " " - the number of samples to generate " + " - (optionally) loaded_model_kwargs " + "(any relevant objects created along with the model) " "and returns: " - " - the inputs, labels and (optionally) outputs as torch tensors ", + " - the inputs, (optionally) labels and (optionally) outputs as torch tensors ", ) - deployment_directory_structure: List[str] = Field( + deployment_directory_files_mandatory: List[str] = Field( description="A list that describes the " - "expected files of the deployment directory", + "mandatory expected files of the deployment directory", default=["model.onnx"], ) + + deployment_directory_files_optional: Optional[List[str]] = Field( + description="A list that describes the " + "optional expected files of the deployment directory", + ) diff --git a/src/sparseml/pytorch/image_classification/integration_helper_functions.py b/src/sparseml/pytorch/image_classification/integration_helper_functions.py index 4d90908c4a3..b90585ef446 100644 --- a/src/sparseml/pytorch/image_classification/integration_helper_functions.py +++ b/src/sparseml/pytorch/image_classification/integration_helper_functions.py @@ -13,11 +13,12 @@ # limitations under the License. 
 
 import os
 from pathlib import Path
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 from pydantic import Field
 
+from src.sparseml.export.export_data import create_data_samples as create_data_samples_
 from src.sparseml.integration_helper_functions import (
     IntegrationHelperFunctions,
     Integrations,
 )
@@ -36,12 +37,13 @@
 
 def create_model(
     source_path: Union[Path, str],
-    batch_size: Optional[int],
-    device: Optional[str],
+    batch_size: Optional[int] = None,
+    device: Optional[str] = None,
     **kwargs,
-) -> Tuple[torch.nn.Module, Optional[torch.utils.data.DataLoader]]:
+) -> Tuple[torch.nn.Module, Dict[str, Any]]:
     """
-    A contract to create a model and optionally a validation dataloader
+    A contract to create a model and optional dictionary of
+    loaded_model_kwargs (any relevant objects created along with the model)
 
     :param source_path: The path to the model
     :param batch_size: The batch size to use for the dataloader creation
     :param device: The device to use for the model and dataloader instantiation
 
     :return: A tuple of the
         - torch model
-        - (optionally) validation dataloader
+        - (optionally) loaded_model_kwargs
+          (any relevant objects created along with the model)
     """
     checkpoint_path = (
         os.path.join(source_path, "model.pth")
@@ -92,39 +95,55 @@ def create_model(
         checkpoint_path=checkpoint_path, **kwargs
     )
 
-    return model, validation_dataloader
+    return model, dict(validation_dataloader=validation_dataloader)
 
 
 def create_dummy_input(
     validation_dataloader: Optional[torch.utils.data.DataLoader] = None,
-    **kwargs: Any,
+    image_size: Optional[int] = None,
+    **kwargs,
 ) -> torch.Tensor:
     """
     A contract to create a dummy input for a model
 
     :param validation_dataloader: The validation dataloader to get a batch from.
         If None, a fake batch will be created
+    :param image_size: The image size to use for the dummy input
     :return: The dummy input as a torch tensor
     """
     if not validation_dataloader:
         # create fake data for export
-        batch_size = kwargs.get("batch_size", 1)
-        image_size = kwargs.get("image_size")
         if image_size is None:
             raise ValueError(
                 "In the absence of validation_dataloader, the "
                 "image_size must be provided to create a dummy input"
             )
-        validation_dataloader = [[torch.randn(batch_size, 3, image_size, image_size)]]
+        validation_dataloader = [[torch.randn(1, 3, image_size, image_size)]]
     return next(iter(validation_dataloader))[0]
 
 
+def create_data_samples(
+    num_samples: int,
+    validation_dataloader: Optional[torch.utils.data.DataLoader] = None,
+    model: Optional["torch.nn.Module"] = None,
+    **kwargs,
+):
+    if validation_dataloader is None:
+        raise ValueError(
+            "Attempting to create data samples without a validation dataloader."
+        )
+
+    return create_data_samples_(
+        data_loader=validation_dataloader, model=model, num_samples=num_samples
+    )
+
+
 @IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
 class ImageClassification(IntegrationHelperFunctions):
-
-    create_model: Callable[
-        ..., Tuple[torch.nn.Module, Optional[torch.utils.data.DataLoader]]
-    ] = Field(default=create_model)
+    create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
+        default=create_model
+    )
     create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
+    create_data_samples: Callable = Field(create_data_samples)
diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py
index f2f5eccc7d0..af23894e5dd 100644
--- a/src/sparseml/pytorch/utils/helpers.py
+++ b/src/sparseml/pytorch/utils/helpers.py
@@ -137,6 +137,15 @@ def default_device() -> str:
     return "cuda:{}".format(",".join(device_ids))
 
 
+def use_single_gpu(device: str) -> str:
+    """
+    :param device: the device string, potentially listing multiple cuda devices
+    :return: the first gpu in the device string if multiple are available
+    """
+    if "cuda" not in device:
+        raise ValueError("use_single_gpu should only be called on cuda devices")
+    return device.split(",")[0]
+
+
 def device_of(inputs: Any):
     if isinstance(inputs, Tensor):
         return inputs.device
@@ -538,7 +547,8 @@ def _tensors_export_batch(
         return
 
     if isinstance(tensors, Iterable):
-        for index, tens in enumerate(zip(*tensors)):
+        # TODO: verify that zip(tensors) (previously zip(*tensors)) is correct - dbogunowicz
+        for index, tens in enumerate(zip(tensors)):
             exported_paths.append(
                 tensor_export(
                     tens, export_dir, "{}-{:04d}".format(name_prefix, counter + index)
diff --git a/src/sparseml/transformers/integration_helper_functions.py b/src/sparseml/transformers/integration_helper_functions.py
new file mode 100644
index 00000000000..17db6ea2666
--- /dev/null
+++ b/src/sparseml/transformers/integration_helper_functions.py
@@ -0,0 +1,211 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
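+
+# Illustrative end-to-end usage of this integration (a sketch only; the
+# paths and the task name below are placeholders, not shipped values):
+#
+#     from sparseml.export import export
+#
+#     export(
+#         source_path="./my-sparse-bert",
+#         target_path="./exported",
+#         task="question-answering",
+#     )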
+
+import logging
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from pydantic import Field
+from transformers import AutoTokenizer
+
+from sparseml.export.helpers import apply_optimizations as apply_optimizations_onnx
+from sparseml.transformers.sparsification.trainer import Trainer
+from sparseml.transformers.utils.helpers import (
+    MANDATORY_DEPLOYMENT_FILES,
+    NLG_TOKENIZER_FILES,
+    OPTIONAL_DEPLOYMENT_FILES,
+    TaskNames,
+)
+from sparseml.transformers.utils.load_task_dataset import load_task_dataset
+from sparseml.transformers.utils.optimizations import apply_kv_cache_injection
+from src.sparseml.export.export_data import create_data_samples as create_data_samples_
+from src.sparseml.integration_helper_functions import (
+    IntegrationHelperFunctions,
+    Integrations,
+)
+from src.sparseml.transformers.utils.initializers import (
+    _parse_data_args,
+    initialize_config,
+    initialize_model,
+    initialize_tokenizer,
+    initialize_trainer,
+    resolve_sequence_length,
+)
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def create_model(
+    source_path: Union[Path, str],
+    device: Optional[str] = None,
+    task: Optional[str] = None,
+    **kwargs,
+) -> Tuple[torch.nn.Module, Dict[str, Any]]:
+    """
+    A contract to create a model and optional dictionary of
+    loaded_model_kwargs (any relevant objects created along with the model)
+
+    :param source_path: The path to the model
+    :param device: The device to use for the model and dataloader instantiation
+    :param task: The task to use for the model and dataloader instantiation
+
+    :return: A tuple of the
+        - torch model
+        - (optionally) loaded_model_kwargs
+          (any relevant objects created along with the model)
+    """
+    config_args = kwargs.get("config_args", {})
+    sequence_length = kwargs.get("sequence_length", None)
+    data_args = kwargs.get("data_args", {})
+    trust_remote_code = kwargs.get("trust_remote_code", False)
+
+    if task is None:
+        raise ValueError("To create a transformer model, a task must be specified")
+    task = task.replace("_", "-")
+
+    if not trust_remote_code:
+        _LOGGER.warning(
+            "trust_remote_code is set to False. It is possible "
+            "that the model will not be loaded correctly."
+        )
+
+    config = initialize_config(source_path, trust_remote_code, **config_args)
+    sequence_length = sequence_length or resolve_sequence_length(config)
+    tokenizer = initialize_tokenizer(source_path, sequence_length, task)
+    model = initialize_model(
+        model_path=source_path,
+        task=task,
+        config=config,
+        trust_remote_code=trust_remote_code,
+        device=device,
+    )
+
+    data_args = _parse_data_args(data_args)
+
+    if data_args:
+        dataset = load_task_dataset(
+            task=task,
+            tokenizer=tokenizer,
+            data_args=data_args,
+            model=model,
+            config=config,
+        )
+        validation_dataset = dataset.get("validation")
+
+    else:
+        validation_dataset = None
+
+    model.train()
+    trainer = initialize_trainer(model, source_path, validation_dataset)
+    model.eval()
+
+    return model, dict(
+        trainer=trainer,
+        tokenizer=tokenizer,
+        input_names=list(next(trainer._get_fake_dataloader(1, tokenizer)).keys()),
+    )
+
+
+def create_dummy_input(
+    trainer: Optional[Trainer] = None,
+    tokenizer: Optional[AutoTokenizer] = None,
+    **kwargs,
+) -> torch.Tensor:
+    if trainer is not None and trainer.eval_dataset is not None:
+        data_loader = trainer.get_eval_dataloader()
+    else:
+        if not tokenizer:
+            raise ValueError(
+                "Tokenizer is needed to generate "
+                "fake sample inputs when the trainer is "
+                "not initialized with an eval dataset"
+            )
+        data_loader = trainer._get_fake_dataloader(num_samples=1, tokenizer=tokenizer)
+    return next(iter(data_loader))
+
+
+def create_data_samples(
+    num_samples: int,
+    trainer: Trainer,
+    model: Optional["torch.nn.Module"] = None,
+    **kwargs,
+):
+    if kwargs.get("batch_size"):
+        _LOGGER.info(
+            "For exporting samples for the transformers integration, "
+            "the batch size is ignored (equal to 1)"
+        )
+    if trainer.eval_dataset is None:
+        raise ValueError(
+            "Attempting to create data samples without an eval dataloader. "
" + "Initialize a trainer with an eval dataset" + ) + + return create_data_samples_( + data_loader=trainer.get_eval_dataloader(), model=model, num_samples=num_samples + ) + + +def apply_optimizations_generative_transformer( + exported_file_path: Union[str, Path], + optimizations: Union[str, List[str]], + single_graph_file: bool = True, +): + + if exported_file_path.endswith(".onnx"): + available_optimizations = dict(kv_cache_injection=apply_kv_cache_injection) + apply_optimizations_onnx( + onnx_file_path=exported_file_path, + target_optimizations=optimizations, + available_optimizations=available_optimizations, + single_graph_file=single_graph_file, + ) + else: + raise NotImplementedError( + "Applying optimizations is only supported for ONNX files" + ) + + +@IntegrationHelperFunctions.register(name=Integrations.transformers.value) +class Transformers(IntegrationHelperFunctions): + def __init__(self, *args, **kwargs): + super().__init__() + task = kwargs.get("task") + if task is None: + _LOGGER.warning("The task for transformers is not specified.") + elif task in TaskNames.text_generation.value: + # if the task is text generation, alter the default attributes + # to reflect the idiosyncrasies for text generation + self.apply_optimizations = apply_optimizations_generative_transformer + self.deployment_directory_files_mandatory = list( + MANDATORY_DEPLOYMENT_FILES.union(NLG_TOKENIZER_FILES) + ) + else: + _LOGGER.info( + "Fetching default helper functions for transformers integration" + ) + + create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field( + default=create_model + ) + create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input) + create_data_samples: Callable = Field(create_data_samples) + deployment_directory_files_mandatory: List[str] = Field( + default=list(MANDATORY_DEPLOYMENT_FILES) + ) + deployment_directory_files_optional: List[str] = Field( + default=list(OPTIONAL_DEPLOYMENT_FILES) + ) diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index f28f04cc560..96e2767e7fc 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -22,12 +22,10 @@ import math import os import warnings -from contextlib import suppress from dataclasses import asdict from typing import Any, Dict, List, Optional, Tuple, Union import datasets -import numpy import torch from torch import distributed as dist from torch.nn import Module @@ -499,90 +497,6 @@ def log_model_sparsification(self): f"all sparsification info: {sparsification_info}" ) - def save_sample_inputs_outputs( - self, - num_samples_to_export: int = 100, - output_dir: Optional[str] = None, - tokenizer: Optional[Any] = None, - ): - """ - Save sample inputs/outputs/labels in save_dir as .npz arrays - - :param num_samples_to_export: Number of samples to export. 
-            Defaults to 100
-        :param output_dir: The directory to store sample inputs and outputs in
-        :param tokenizer: if eval and train dataset cannot be generated, then
-            the tokenizer is used to generate fake inputs
-        """
-        num_samples = 0
-
-        if output_dir is None:
-            output_dir = (
-                self.args.output_dir if hasattr(self.args, "output_dir") else ""
-            )
-
-        sample_in_dir = os.path.join(output_dir, "sample-inputs")
-        sample_out_dir = os.path.join(output_dir, "sample-outputs")
-
-        os.makedirs(sample_in_dir, exist_ok=True)
-        os.makedirs(sample_out_dir, exist_ok=True)
-        device = self.model.device
-
-        dataloader = None
-        try:
-            dataloader = self.get_eval_dataloader()
-        except Exception:
-            with suppress(ValueError):
-                dataloader = self.get_train_dataloader()
-
-        if not dataloader and not tokenizer:
-            raise ValueError(
-                "tokenizer is needed to generate fake sample inputs when Trainer is "
-                "not initialized with a train or eval dataset"
-            )
-        if dataloader is None:
-            # we have the tokenizer so use it
-            dataloader = self._get_fake_dataloader(
-                num_samples=num_samples_to_export, tokenizer=tokenizer
-            )
-
-        _LOGGER.info(
-            f"Exporting {num_samples_to_export} samples to "
-            f"{os.path.abspath(output_dir)}"
-        )
-        for _, sample_batch in enumerate(dataloader):
-            sample_batch.pop("labels", None)
-            input_names = list(sample_batch.keys())
-
-            for input_vals in zip(*sample_batch.values()):
-                input_feed = {k: v.to("cpu") for k, v in zip(input_names, input_vals)}
-                model_inputs = {
-                    k: input_feed[k].to(device).reshape(1, -1) for k in input_feed
-                }
-                output_vals = self.model(**model_inputs)
-                output_dict = {
-                    name: torch.squeeze(val).detach().to("cpu")
-                    for name, val in output_vals.items()
-                }
-                file_idx = f"{num_samples}".zfill(4)
-
-                sample_input_filename = os.path.join(
-                    f"{sample_in_dir}", f"inp-{file_idx}.npz"
-                )
-                numpy.savez(sample_input_filename, **input_feed)
-
-                sample_output_filename = os.path.join(
-                    f"{sample_out_dir}", f"out-{file_idx}.npz"
-                )
-                numpy.savez(sample_output_filename, **output_dict)
-                num_samples += 1
-
-                if num_samples >= num_samples_to_export:
-                    break
-            if num_samples >= num_samples_to_export:
-                break
-        _LOGGER.info(f"Exported {num_samples_to_export} samples to {output_dir}")
-
     def _extract_metadata(
         self,
         metadata_args: List[str],
@@ -789,7 +703,7 @@ def _get_fake_dataloader(
         num_samples: int,
         tokenizer: "PreTrainedTokenizerBase",  # noqa: F821
     ):
-        # Rearrange inputs' keys to match those defined by model foward func, which
+        # Rearrange inputs' keys to match those defined by model forward func, which
         # seem to define how the order of inputs is determined in the exported model
         forward_args_spec = inspect.getfullargspec(self.model.__class__.forward)
         synthetic_input = self._get_fake_input(
diff --git a/src/sparseml/transformers/utils/__init__.py b/src/sparseml/transformers/utils/__init__.py
index 4647a435f12..16bfb5062e4 100644
--- a/src/sparseml/transformers/utils/__init__.py
+++ b/src/sparseml/transformers/utils/__init__.py
@@ -17,7 +17,6 @@
 """
 
 # flake8: noqa
-
 from .helpers import *
 from .metrics import *
 from .model import *
diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py
index c796f5a39ac..acea82bd101 100644
--- a/src/sparseml/transformers/utils/helpers.py
+++ b/src/sparseml/transformers/utils/helpers.py
@@ -16,22 +16,63 @@
 Helper variables and functions for integrating SparseML with
 huggingface/transformers flows
 """
+
 import logging
 import os
-from typing import Optional
+from enum import Enum
+from pathlib import Path
+from typing import Optional, Union
 
 from transformers.trainer_utils import get_last_checkpoint
 
+from sparseml.export.helpers import ONNX_MODEL_NAME
 from sparsezoo import setup_model
 
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-__all__ = ["RECIPE_NAME", "save_zoo_directory", "detect_last_checkpoint"]
+__all__ = [
+    "RECIPE_NAME",
+    "save_zoo_directory",
+    "detect_last_checkpoint",
+    "TaskNames",
+    "is_transformer_model",
+]
+
+
+class TaskNames(Enum):
+    mlm = {"masked-language-modeling", "mlm"}
+    qa = {"question-answering", "qa"}
+    token_classification = {"token-classification", "ner"}
+    text_classification = {
+        "text-classification",
+        "sentiment-analysis",
+        "sequence-classification",
+        "glue",
+    }
+    text_generation = {"text-generation"}
 
 
 RECIPE_NAME = "recipe.yaml"
+MANDATORY_DEPLOYMENT_FILES = {
+    ONNX_MODEL_NAME,
+    "tokenizer_config.json",
+    "config.json",
+}
+NLG_TOKENIZER_FILES = {"special_tokens_map.json", "vocab.json", "merges.txt"}
+OPTIONAL_DEPLOYMENT_FILES = {"tokenizer.json", "tokenizer.model"}
+
+
+def is_transformer_model(source_path: Union[Path, str]) -> bool:
+    """
+    :param source_path: The path to the model
+    :return: Whether the model is a transformers model or not
+    """
+    if not os.path.isdir(source_path):
+        raise ValueError(f"Path {source_path} is not a valid directory")
+    expected_files = MANDATORY_DEPLOYMENT_FILES.difference({ONNX_MODEL_NAME})
+    return expected_files.issubset(os.listdir(source_path))
 
 
 def save_zoo_directory(
diff --git a/src/sparseml/transformers/utils/initializers.py b/src/sparseml/transformers/utils/initializers.py
new file mode 100644
index 00000000000..8f14375931e
--- /dev/null
+++ b/src/sparseml/transformers/utils/initializers.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
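+
+# A sketch of how these initializers compose (illustrative only; `model_path`
+# and `task` are placeholders):
+#
+#     config = initialize_config(model_path)
+#     tokenizer = initialize_tokenizer(
+#         model_path, resolve_sequence_length(config), task
+#     )
+#     model = initialize_model(model_path, task, config)
+#     trainer = initialize_trainer(model, model_path, validation_dataset=None)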
+
+"""
+Functionality for initializing a transformer model from a given path
+"""
+
+import logging
+import math
+import os
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from transformers import AutoConfig, AutoModel, AutoTokenizer, TrainingArguments
+
+from sparseml.optim import parse_recipe_variables
+from sparseml.transformers.sparsification import Trainer
+from sparseml.transformers.utils.helpers import TaskNames
+from sparseml.transformers.utils.load_task_model import load_task_model
+
+
+__all__ = [
+    "initialize_model",
+    "initialize_tokenizer",
+    "initialize_trainer",
+    "initialize_config",
+    "resolve_sequence_length",
+]
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def initialize_config(
+    model_path: Union[str, Path], trust_remote_code: bool = False, **config_args
+) -> AutoConfig:
+    """
+    Initialize a config from a given path
+
+    :param model_path: the path to the model to load
+    :param trust_remote_code: True to trust remote code when loading the model,
+        False otherwise
+    :param config_args: additional arguments to pass to the config
+    :return: the loaded config
+    """
+    config = AutoConfig.from_pretrained(
+        model_path,
+        trust_remote_code=trust_remote_code,
+        **config_args,
+    )
+    return config
+
+
+def initialize_tokenizer(
+    model_path: Union[str, Path], sequence_length: int, task: str
+) -> AutoTokenizer:
+    """
+    Initialize a tokenizer from a given path
+
+    :param model_path: the path to the model to load
+    :param sequence_length: the sequence length to use for the tokenizer
+    :param task: the task to load the tokenizer for
+    :return: the loaded tokenizer
+    """
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path, model_max_length=sequence_length
+    )
+    if task in TaskNames.text_generation.value:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
+
+
+def initialize_model(
+    model_path: Union[str, Path],
+    task: str,
+    config: AutoConfig,
+    trust_remote_code: bool = False,
+    device: Optional[str] = None,
+) -> AutoModel:
+    """
+    Initialize a model from a given path
+
+    :param model_path: the path to the model to load
+    :param task: the task to load the model for
+    :param config: the config to use for the model
+    :param trust_remote_code: True to trust remote code when loading the model,
+        False otherwise
+    :param device: the device to load the model on. If None, will load on CPU
+    :return: the loaded model
+    """
+    model = load_task_model(
+        task=task,
+        model_path=model_path,
+        config=config,
+        trust_remote_code=trust_remote_code,
+    )
+    if device:
+        model.to(device)
+    return model
+
+
+def initialize_trainer(
+    model: AutoModel,
+    model_path: Union[str, Path],
+    validation_dataset: Optional[Any] = None,
+) -> Trainer:
+    """
+    Initialize a trainer. This will apply the structure dictated by
+    any of the recipes stored in the model_path
+
+    :param model: the model to initialize the trainer with
+    :param model_path: the path to the model to load
+    :param validation_dataset: the validation dataset to use for the trainer
+    :return: the initialized trainer
+    """
+
+    training_args = TrainingArguments(
+        output_dir=os.path.dirname(model_path), use_cpu=(model.device.type == "cpu")
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        model_state_path=model_path,
+        eval_dataset=validation_dataset,
+        recipe=None,
+        recipe_args=None,
+        teacher=None,
+    )
+    applied = trainer.apply_manager(epoch=math.inf, checkpoint=None)
+
+    if not applied:
+        _LOGGER.warning(
+            f"No recipes were applied for {model_path}, "
+            "check to make sure recipe(s) are stored in the model_path"
+        )
+    else:
+        trainer.finalize_manager()
+        num_stages = 0
+        if trainer.manager:
+            num_stages += trainer.manager.num_stages()
+        if trainer.arch_manager:
+            num_stages += trainer.arch_manager.num_stages()
+
+        msg = (
+            "an unstaged recipe"
+            if num_stages == 1
+            else f"a staged recipe with {num_stages} stages"
+        )
+        _LOGGER.info(f"Applied {msg} to the model at {model_path}")
+
+    return trainer
+
+
+def resolve_sequence_length(config: AutoConfig) -> int:
+    """
+    Resolve the sequence length from the config
+
+    :param config: the config to resolve the sequence length from
+    :return: the sequence length
+    """
+    if hasattr(config, "max_position_embeddings"):
+        sequence_length = config.max_position_embeddings
+
+    elif hasattr(config, "max_seq_len"):
+        sequence_length = config.max_seq_len
+    else:
+        raise ValueError(
+            "Could not infer a default sequence length "
+            "from the HF transformers config. Please specify "
+            "the sequence length with --sequence_length"
+        )
+    _LOGGER.info(
+        f"Using default sequence length of {sequence_length} "
+        "(inferred from HF transformers config)"
+    )
+    return sequence_length
+
+
+def _parse_data_args(data_args):
+    try:
+        return parse_recipe_variables(data_args)
+    except ValueError as parse_error:
+        message = str(parse_error).replace("recipe_args", "data_args")
+        if "recipe variables" in message:
+            message = message.replace("recipe variables", "data_args")
+        raise ValueError(message)
diff --git a/src/sparseml/transformers/utils/load_task_dataset.py b/src/sparseml/transformers/utils/load_task_dataset.py
new file mode 100644
index 00000000000..cb718b02696
--- /dev/null
+++ b/src/sparseml/transformers/utils/load_task_dataset.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
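+
+# Illustrative usage (a sketch only; the data_args keys are placeholders and
+# must match the fields of the task's DataTrainingArguments):
+#
+#     dataset = load_task_dataset(
+#         task="question-answering",
+#         tokenizer=tokenizer,
+#         data_args={"dataset_name": "squad"},
+#         model=model,
+#     )
+#     validation_dataset = dataset.get("validation")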
+
+from typing import Any, Dict, Optional
+
+from torch.nn import Module
+from transformers import AutoConfig, AutoTokenizer
+
+from sparseml.transformers.utils.helpers import TaskNames
+
+
+__all__ = ["load_task_dataset"]
+
+
+def load_task_dataset(
+    task: str,
+    tokenizer: AutoTokenizer,
+    data_args: Dict[str, Any],
+    model: Module,
+    config: Optional[AutoConfig] = None,
+) -> Any:
+    """
+    Load a dataset for a given task.
+
+    :param task: the task the dataset is being loaded for
+    :param tokenizer: the tokenizer to use for the dataset
+    :param data_args: additional data args used to create a `DataTrainingArguments`
+        instance for fetching the dataset
+    :param model: the model to use for the dataset
+    :param config: the config to use for the dataset
+    :return: the dataset for the given task
+    """
+
+    if task in TaskNames.mlm.value:
+        from sparseml.transformers.masked_language_modeling import (
+            DataTrainingArguments,
+            get_tokenized_mlm_dataset,
+        )
+
+        data_training_args = DataTrainingArguments(**data_args)
+        return get_tokenized_mlm_dataset(
+            data_args=data_training_args, tokenizer=tokenizer
+        )
+
+    if task in TaskNames.qa.value:
+        from sparseml.transformers.question_answering import (
+            DataTrainingArguments,
+            get_tokenized_qa_dataset,
+        )
+
+        data_training_args = DataTrainingArguments(**data_args)
+        return get_tokenized_qa_dataset(
+            data_args=data_training_args, tokenizer=tokenizer
+        )
+
+    if task in TaskNames.token_classification.value:
+        from sparseml.transformers.token_classification import (
+            DataTrainingArguments,
+            get_tokenized_token_classification_dataset,
+        )
+
+        data_training_args = DataTrainingArguments(**data_args)
+        return get_tokenized_token_classification_dataset(
+            data_args=data_training_args, tokenizer=tokenizer, model=model or config
+        )
+
+    if task in TaskNames.text_classification.value:
+        from sparseml.transformers.text_classification import (
+            DataTrainingArguments,
+            get_tokenized_text_classification_dataset,
+        )
+
+        data_training_args = DataTrainingArguments(**data_args)
+
+        return get_tokenized_text_classification_dataset(
+            data_args=data_training_args,
+            tokenizer=tokenizer,
+            model=model,
+            config=config,
+        )
+
+    raise ValueError(f"unrecognized task given of {task}")
diff --git a/src/sparseml/transformers/utils/load_task_model.py b/src/sparseml/transformers/utils/load_task_model.py
new file mode 100644
index 00000000000..cb417a51960
--- /dev/null
+++ b/src/sparseml/transformers/utils/load_task_model.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
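+
+# Illustrative usage (a sketch only; `model_path` is a placeholder):
+#
+#     from transformers import AutoConfig
+#
+#     config = AutoConfig.from_pretrained(model_path)
+#     model = load_task_model(task="qa", model_path=model_path, config=config)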
+
+from typing import Any
+
+from torch.nn import Module
+
+from sparseml.transformers.utils.helpers import TaskNames
+from sparseml.transformers.utils.model import SparseAutoModel
+
+
+__all__ = ["load_task_model"]
+
+
+def load_task_model(
+    task: str, model_path: str, config: Any, trust_remote_code: bool = False
+) -> Module:
+    """
+    Load the appropriate SparseAutoModel for the given task
+
+    :param task: the task to load the model for
+    :param model_path: the path to the model to load
+    :param config: the config to use for the model
+    :param trust_remote_code: True to trust remote code when loading the model
+    :return: the loaded model
+    """
+    if task in TaskNames.mlm.value:
+        return SparseAutoModel.masked_language_modeling_from_pretrained(
+            model_name_or_path=model_path,
+            config=config,
+            model_type="model",
+            trust_remote_code=trust_remote_code,
+        )
+
+    if task in TaskNames.qa.value:
+        return SparseAutoModel.question_answering_from_pretrained(
+            model_name_or_path=model_path,
+            config=config,
+            model_type="model",
+            trust_remote_code=trust_remote_code,
+        )
+
+    if task in TaskNames.text_classification.value:
+        return SparseAutoModel.text_classification_from_pretrained(
+            model_name_or_path=model_path,
+            config=config,
+            model_type="model",
+            trust_remote_code=trust_remote_code,
+        )
+
+    if task in TaskNames.token_classification.value:
+        return SparseAutoModel.token_classification_from_pretrained(
+            model_name_or_path=model_path,
+            config=config,
+            model_type="model",
+            trust_remote_code=trust_remote_code,
+        )
+
+    if task in TaskNames.text_generation.value:
+        return SparseAutoModel.text_generation_from_pretrained(
+            model_name_or_path=model_path,
+            config=config,
+            model_type="model",
+            trust_remote_code=trust_remote_code,
+        )
+
+    raise ValueError(f"unrecognized task given of {task}")
diff --git a/src/sparseml/transformers/utils/optimizations.py b/src/sparseml/transformers/utils/optimizations.py
new file mode 100644
index 00000000000..20b58775b10
--- /dev/null
+++ b/src/sparseml/transformers/utils/optimizations.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
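+
+# Illustrative usage (a sketch only; the path is a placeholder for an
+# exported decoder-only model):
+#
+#     applied = apply_kv_cache_injection("deployment/model.onnx")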
+
+import logging
+import os
+from pathlib import Path
+from typing import Union
+
+import onnx
+
+from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector
+
+
+__all__ = ["apply_kv_cache_injection"]
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def apply_kv_cache_injection(onnx_model_path: Union[str, Path]) -> bool:
+    """
+    Apply key value cache injection to an ONNX model
+
+    :param onnx_model_path: path to the ONNX model to inject
+    :return: True once the injection has been applied
+    """
+    onnx_model = onnx.load(onnx_model_path, load_external_data=False)
+    model_path = os.path.dirname(onnx_model_path)
+    exporter = KeyValueCacheInjector(model_path=model_path)
+    exporter.export(onnx_model, onnx_model_path)
+    return True
diff --git a/tests/sparseml/export/test_export_data.py b/tests/sparseml/export/test_export_data.py
index 26b6af46b17..fa7dd6682c4 100644
--- a/tests/sparseml/export/test_export_data.py
+++ b/tests/sparseml/export/test_export_data.py
@@ -19,6 +19,7 @@
 import pytest
 
 from sparseml.export.export_data import create_data_samples, export_data_sample
+from tests.sparseml.export.utils import get_dummy_dataset
 
 
 @pytest.fixture()
@@ -72,50 +73,65 @@ def test_export_data_sample(tmp_path, as_tar, dummy_names, dummy_samples):
     [True, False],
 )
 @pytest.mark.parametrize("num_samples", [0, 1, 5])
-def test_create_data_samples(num_samples, model):
+def test_create_data_samples_transformers(num_samples, model):
     pytest.importorskip("torch", reason="test requires pytorch")
-    import torch
-    from torch.utils.data import DataLoader, Dataset
-
-    model = torch.nn.Sequential(torch.nn.Identity()) if model else None
+    import torch
+    from torch.utils.data import DataLoader
 
-    class DummyDataset(Dataset):
-        def __init__(self, inputs, outputs):
-            self.data = inputs
-            self.target = outputs
+    class Identity(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.dummy_param = torch.nn.Parameter(torch.empty(0))
+            self.device = self.dummy_param.device
 
-        def __len__(self):
-            return len(self.data)
+        def forward(self, input_ids, attention_mask):
+            return dict(input_ids=input_ids, attention_mask=attention_mask)
 
-        def __getitem__(self, index):
-            data_sample = self.data[index]
-            target_sample = self.target[index]
+    model = Identity().to("cpu") if model else None
 
-            return data_sample, target_sample
+    data_loader = DataLoader(get_dummy_dataset("transformers"), batch_size=1)
 
-    inputs = torch.randn((100, 3, 224, 224))
-    labels = torch.randint(
-        0,
-        10,
-        (
-            100,
-            50,
-        ),
+    inputs, outputs, labels = create_data_samples(
+        data_loader=data_loader, num_samples=num_samples, model=model
     )
+    target_input = next(iter(data_loader))
+    target_output = target_input
+
+    assert len(inputs) == num_samples
+    for input in inputs:
+        for key, value in input.items():
+            assert torch.equal(value, target_input[key])
+    assert labels == []
+    if model is not None:
+        assert len(outputs) == num_samples
+        for output in outputs:
+            for key, value in output.items():
+                assert torch.equal(value, target_output[key][0])
+
+
+@pytest.mark.parametrize(
+    "model",
+    [True, False],
+)
+@pytest.mark.parametrize("num_samples", [0, 1, 5])
+def test_create_data_samples_image_classification(num_samples, model):
     pytest.importorskip("torch", reason="test requires pytorch")
 
-    custom_dataset = DummyDataset(inputs, labels)
+    import torch
+    from torch.utils.data import DataLoader
 
-    data_loader = DataLoader(custom_dataset, batch_size=1)
+    model = torch.nn.Sequential(torch.nn.Identity()) if model else None
+    data_loader = DataLoader(get_dummy_dataset("image-classification"), batch_size=1)
 
     inputs, outputs, labels = create_data_samples(
         data_loader=data_loader, num_samples=num_samples, model=model
     )
-
-    assert all(tuple(input.shape) == (1, 3, 224, 224) for input in inputs)
-    assert all(tuple(label.shape) == (1, 50) for label in labels)
+    target_input, target_label = next(iter(data_loader))
+    target_output = target_input
+    assert all(input.shape == target_input.shape for input in inputs)
+    assert all(label.shape == target_label.shape for label in labels)
     assert len(inputs) == num_samples == len(labels)
     if model is not None:
         assert len(outputs) == num_samples
-        assert all(tuple(output.shape) == (1, 3, 224, 224) for output in outputs)
+        assert all(output.shape == target_output.shape for output in outputs)
diff --git a/tests/sparseml/export/test_helpers.py b/tests/sparseml/export/test_helpers.py
index 1934004a64e..e1eb6731709 100644
--- a/tests/sparseml/export/test_helpers.py
+++ b/tests/sparseml/export/test_helpers.py
@@ -74,7 +74,7 @@ def test_create_deployment_folder(
     create_deployment_folder(
         source_path=source_path,
         target_path=target_path,
-        deployment_directory_files=deployment_directory_list,
+        deployment_directory_files_mandatory=deployment_directory_list,
     )
 
     assert (
@@ -83,14 +83,14 @@
     )
 
 
-def foo(onnx_model):
-    logging.debug("foo")
-    return onnx_model
+def foo(*args, **kwargs):
+    logging.debug("foo applied")
+    return True
 
 
-def bar(onnx_model):
-    logging.debug("bar")
-    return onnx_model
+def bar(*args, **kwargs):
+    logging.debug("bar applied")
+    return True
 
 
 @pytest.fixture()
@@ -133,8 +133,8 @@ def test_apply_optimizations_empty(
     "target_optimizations, expected_logs, should_raise_error",
     [
         ("none", [], False),
-        ("all", ["bar", "foo"], False),
-        (["foo"], ["foo"], False),
+        ("all", ["bar applied", "foo applied"], False),
+        (["foo"], ["foo applied"], False),
         ("error_name", [], True),
         (["error_name"], [], True),
     ],
@@ -167,4 +167,4 @@ def test_apply_optimizations(
         available_optimizations=available_optimizations,
     )
 
-    assert caplog.messages == expected_logs
+    assert set(expected_logs).issubset(set(caplog.messages))
diff --git a/tests/sparseml/export/transformers/__init__.py b/tests/sparseml/export/transformers/__init__.py
new file mode 100644
index 00000000000..0c44f887a47
--- /dev/null
+++ b/tests/sparseml/export/transformers/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/sparseml/export/transformers/test_generative_transformers.py b/tests/sparseml/export/transformers/test_generative_transformers.py
new file mode 100644
index 00000000000..ebbd902c343
--- /dev/null
+++ b/tests/sparseml/export/transformers/test_generative_transformers.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+
+import onnx
+import pytest
+import torch
+
+from huggingface_hub import snapshot_download
+from sparseml.export.export import export
+
+
+@pytest.mark.parametrize(
+    "stub, task",
+    [("roneneldan/TinyStories-1M", "text-generation")],
+)
+class TestEndToEndExport:
+    @pytest.fixture()
+    def setup(self, tmp_path, stub, task):
+        model_path = tmp_path / "model"
+        target_path = tmp_path / "target"
+
+        source_path = snapshot_download(stub, local_dir=model_path)
+
+        yield source_path, target_path, task
+
+        shutil.rmtree(tmp_path)
+
+    def test_export_happy_path(self, setup):
+        source_path, target_path, task = setup
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+        # check if kv cache injection has been applied
+        onnx_model = onnx.load(
+            str(target_path / "deployment" / "model.onnx"), load_external_data=False
+        )
+        assert any(
+            inp.name == "past_key_values.0.key" for inp in onnx_model.graph.input
+        )
+
+    def test_export_with_sample_data(self, setup):
+        source_path, target_path, task = setup
+
+        sequence_length = 32
+        sample_data = dict(
+            input_ids=torch.ones((10, sequence_length), dtype=torch.long),
+            attention_mask=torch.ones((10, sequence_length), dtype=torch.long),
+        )
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            sample_data=sample_data,
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+
+    @pytest.mark.skip(
+        reason="skipping since this functionality needs some more attention"
+    )
+    def test_export_validate_correctness(self, setup):
+        source_path, target_path, task = setup
+
+        num_samples = 4
+
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            num_export_samples=num_samples,
+            validate_correctness=True,
+        )
diff --git a/tests/sparseml/export/transformers/test_transformers.py b/tests/sparseml/export/transformers/test_transformers.py
new file mode 100644
index 00000000000..b2959ed1ae4
--- /dev/null
+++ b/tests/sparseml/export/transformers/test_transformers.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
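+
+# End-to-end export tests for (non-generative) transformer checkpoints: each
+# test pulls a SparseZoo stub, runs `export`, and inspects the resulting
+# deployment directory and exported sample data.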
+
+
+import glob
+import os
+import shutil
+
+import numpy as np
+import pytest
+import torch
+
+from sparseml.export.export import export
+from sparsezoo import Model
+
+
+@pytest.mark.parametrize(
+    "stub, task",
+    [
+        ("zoo:obert-medium-squad_wikipedia_bookcorpus-pruned95_quantized", "qa"),
+    ],
+)
+class TestEndToEndExport:
+    @pytest.fixture()
+    def setup(self, tmp_path, stub, task):
+        model_path = tmp_path / "model"
+        target_path = tmp_path / "target"
+
+        source_path = Model(stub, model_path).training.path
+
+        yield source_path, target_path, task
+
+        shutil.rmtree(tmp_path)
+
+    def test_export_happy_path(self, setup):
+        source_path, target_path, task = setup
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+
+    def test_export_samples(self, setup):
+        source_path, target_path, task = setup
+
+        num_samples = 4
+
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            num_export_samples=num_samples,
+            data_args=dict(dataset_name="squad"),
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+        assert (
+            len(os.listdir(os.path.join(target_path, "sample-inputs"))) == num_samples
+        )
+        assert (
+            len(os.listdir(os.path.join(target_path, "sample-outputs"))) == num_samples
+        )
+        assert np.load(
+            glob.glob(os.path.join(target_path, "sample-inputs/*"))[0],
+            allow_pickle=True,
+        )["arr_0"]
+
+    def test_export_with_sample_data(self, setup):
+        source_path, target_path, task = setup
+
+        sequence_length = 32
+        sample_data = dict(
+            input_ids=torch.ones((10, sequence_length), dtype=torch.long),
+            attention_mask=torch.ones((10, sequence_length), dtype=torch.long),
+        )
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            sample_data=sample_data,
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+
+    @pytest.mark.skip(reason="skipping since not implemented")
+    def test_export_multiple_files(self, setup):
+        source_path, target_path, task = setup
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            single_graph_file=False,
+        )
+
+    @pytest.mark.skip(
+        reason="skipping since this functionality needs some more attention"
+    )
+    def test_export_validate_correctness(self, setup):
+        source_path, target_path, task = setup
+
+        num_samples = 4
+
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            num_export_samples=num_samples,
+            validate_correctness=True,
+            data_args=dict(dataset_name="squad"),
+        )
diff --git a/tests/sparseml/export/utils.py b/tests/sparseml/export/utils.py
new file mode 100644
index 00000000000..52a1347818d
--- /dev/null
+++ b/tests/sparseml/export/utils.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
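+
+# Shared utilities for the export tests: small in-memory datasets that stand
+# in for real transformers / image-classification data.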
+
+from typing import Dict, List
+
+import torch
+from torch.utils.data import Dataset
+
+
+class DummyDatasetTransformers(Dataset):
+    def __init__(self, inputs: List[Dict]):
+        self.data = inputs
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        data_sample = self.data[index]
+        return data_sample
+
+
+class DummyDatasetImageClassification(Dataset):
+    def __init__(self, inputs: torch.Tensor, labels: torch.Tensor):
+        self.inputs = inputs
+        self.labels = labels
+
+    def __len__(self):
+        return len(self.inputs)
+
+    def __getitem__(self, index):
+        data_sample = self.inputs[index]
+        label_sample = self.labels[index]
+        return data_sample, label_sample
+
+
+def get_dummy_dataset(integration):
+    if integration == "image-classification":
+        inputs = torch.randn((100, 3, 224, 224))
+        labels = torch.randint(
+            0,
+            10,
+            (
+                100,
+                50,
+            ),
+        )
+        return DummyDatasetImageClassification(inputs, labels)
+    elif integration == "transformers":
+        sample = dict(
+            input_ids=torch.ones((10, 100), dtype=torch.long),
+            attention_mask=torch.ones((10, 100), dtype=torch.long),
+        )
+        return DummyDatasetTransformers([sample for _ in range(100)])
+    else:
+        raise NotImplementedError(
+            f"Getting dummy dataset for integration {integration} not implemented"
+        )
diff --git a/tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py b/tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py
index c92ab7de304..908f0f6e229 100644
--- a/tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py
+++ b/tests/sparseml/pytorch/image_classification/test_integration_helper_functions.py
@@ -27,6 +27,7 @@ def test_integration_helper_functions():
     assert image_classification.create_model
     assert image_classification.create_dummy_input
     assert image_classification.export
-    assert image_classification.graph_optimizations is None
+    assert image_classification.apply_optimizations is None
     assert image_classification.create_data_samples
-    assert image_classification.deployment_directory_structure == ["model.onnx"]
+    assert image_classification.deployment_directory_files_mandatory == ["model.onnx"]
+    assert image_classification.deployment_directory_files_optional is None
diff --git a/tests/sparseml/transformers/test_integration_helper_functions.py b/tests/sparseml/transformers/test_integration_helper_functions.py
new file mode 100644
index 00000000000..a80503707ab
--- /dev/null
+++ b/tests/sparseml/transformers/test_integration_helper_functions.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
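+
+# Registry smoke tests: loading the transformers helper functions, with and
+# without an explicit task, should expose the expected callables and
+# deployment-directory file lists.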
+from src.sparseml.integration_helper_functions import (
+    IntegrationHelperFunctions,
+    Integrations,
+)
+
+
+def test_integration_helper_functions():
+    # import needed to register the object on the fly
+    import sparseml.transformers.integration_helper_functions  # noqa: F401
+
+    transformers = IntegrationHelperFunctions.load_from_registry(
+        Integrations.transformers.value
+    )
+    assert transformers.create_model
+    assert transformers.create_dummy_input
+    assert transformers.export
+    assert transformers.apply_optimizations is None
+    assert transformers.create_data_samples
+    assert set(transformers.deployment_directory_files_mandatory) == {
+        "model.onnx",
+        "tokenizer_config.json",
+        "config.json",
+    }
+    assert set(transformers.deployment_directory_files_optional) == {
+        "tokenizer.json",
+        "tokenizer.model",
+    }
+
+
+def test_integration_helper_function_text_generation():
+    # import needed to register the object on the fly
+    import sparseml.transformers.integration_helper_functions  # noqa: F401
+
+    transformers = IntegrationHelperFunctions.load_from_registry(
+        Integrations.transformers.value, task="text-generation"
+    )
+
+    assert transformers.create_model
+    assert transformers.create_dummy_input
+    assert transformers.export
+    assert transformers.apply_optimizations is not None
+    assert transformers.create_data_samples
+    assert set(transformers.deployment_directory_files_mandatory) == {
+        "model.onnx",
+        "tokenizer_config.json",
+        "config.json",
+        "special_tokens_map.json",
+        "vocab.json",
+        "merges.txt",
+    }
+    assert set(transformers.deployment_directory_files_optional) == {
+        "tokenizer.json",
+        "tokenizer.model",
+    }
diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py
index e2413f492ae..b8ad806f6e7 100644
--- a/tests/sparseml/transformers/utils/test_helpers.py
+++ b/tests/sparseml/transformers/utils/test_helpers.py
@@ -12,12 +12,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict + import pytest +import torch -from sparseml.transformers.utils.helpers import save_zoo_directory +from huggingface_hub import snapshot_download +from sparseml.transformers.utils.helpers import is_transformer_model, save_zoo_directory from sparsezoo import Model +@pytest.fixture() +def generative_model_path(tmp_path): + return snapshot_download("roneneldan/TinyStories-1M", local_dir=tmp_path) + + +@pytest.fixture() +def model_path(tmp_path): + return Model( + "zoo:mobilebert-squad_wikipedia_bookcorpus-14layer_pruned50.4block_quantized", + tmp_path, + ).training.path + + +@pytest.fixture() +def sequence_length(): + return 384 + + +@pytest.fixture() +def dummy_inputs(): + input_ids = torch.zeros((1, 32), dtype=torch.int64) + attention_mask = torch.ones((1, 32), dtype=torch.int64) + + return OrderedDict( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + +@pytest.mark.parametrize( + "stub", + [ + "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned95_obs_quant-none", # noqa E501 + ], +) +def test_is_transformer_model(tmp_path, stub): + zoo_model = Model(stub, tmp_path) + source_path = zoo_model.training.path + assert is_transformer_model(source_path) + + @pytest.mark.parametrize( "stub", [ diff --git a/tests/sparseml/transformers/utils/test_initializers.py b/tests/sparseml/transformers/utils/test_initializers.py new file mode 100644 index 00000000000..3c27e460baf --- /dev/null +++ b/tests/sparseml/transformers/utils/test_initializers.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
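+
+# Tests for the transformers initializer helpers: config, tokenizer, model,
+# and trainer are created from SparseZoo checkpoints and checked for correct
+# device placement and dataloader behavior.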
+
+import shutil
+
+import pytest
+
+from sparseml.pytorch.utils.helpers import default_device, use_single_gpu
+from sparseml.transformers.utils.load_task_dataset import load_task_dataset
+from sparsezoo import Model
+from src.sparseml.transformers.utils.initializers import (
+    initialize_config,
+    initialize_model,
+    initialize_tokenizer,
+    initialize_trainer,
+)
+
+
+@pytest.mark.parametrize(
+    "stub, task, data_args",
+    [
+        (
+            "zoo:obert-medium-squad_wikipedia_bookcorpus-pruned95_quantized",
+            "qa",
+            dict(dataset_name="squad"),
+        ),
+        (
+            "zoo:distilbert-qqp_wikipedia_bookcorpus-pruned80.4block_quantized",
+            "text-classification",
+            None,
+        ),
+    ],
+    scope="class",
+)
+@pytest.mark.parametrize("device", ["auto", "cpu", None], scope="class")
+class TestInitializeModelFlow:
+    @pytest.fixture()
+    def setup(self, tmp_path, stub, task, data_args, device):
+        self.model_path = Model(stub, tmp_path).training.path
+        self.sequence_length = 384
+        self.task = task
+        self.data_args = data_args
+
+        # process device argument
+        device = default_device() if device == "auto" else device
+        # if multiple gpus available use the first one
+        if not (device is None or device == "cpu"):
+            device = use_single_gpu(device)
+        self.device = device
+        yield
+        shutil.rmtree(tmp_path)
+
+    def test_initialize_config(self, setup):
+        assert initialize_config(model_path=self.model_path, trust_remote_code=True)
+
+    def test_initialize_tokenizer(self, setup):
+        tokenizer = initialize_tokenizer(
+            self.model_path, self.sequence_length, self.task
+        )
+        expected_padding_side = (
+            "left" if self.task == "text-generation" else "right"
+        )
+        assert tokenizer.padding_side == expected_padding_side
+        assert tokenizer.model_max_length == self.sequence_length
+
+    def test_initialize_model(self, setup):
+        model = initialize_model(
+            model_path=self.model_path,
+            task=self.task,
+            device=self.device,
+            config=initialize_config(
+                model_path=self.model_path, trust_remote_code=True
+            ),
+        )
+        assert model
+        self._test_model_device(model)
+
+    def test_initialize_trainer(self, setup):
+        if not self.data_args:
+            pytest.skip("To run this test, please provide valid data_args")
+        config = initialize_config(model_path=self.model_path, trust_remote_code=True)
+        model = initialize_model(
+            model_path=self.model_path,
+            task=self.task,
+            device=self.device,
+            config=config,
+        )
+        tokenizer = initialize_tokenizer(
+            self.model_path, self.sequence_length, self.task
+        )
+        dataset = load_task_dataset(
+            task=self.task,
+            tokenizer=tokenizer,
+            data_args=self.data_args,
+            model=model,
+            config=config,
+        )
+        validation_dataset = dataset.get("validation")
+
+        trainer = initialize_trainer(
+            model=model,
+            model_path=self.model_path,
+            validation_dataset=validation_dataset,
+        )
+        # check that the trainer did not move the model to another device
+        self._test_model_device(model)
+        assert trainer.get_eval_dataloader()
+
+    def test_initialize_trainer_no_validation_dataset(self, setup):
+        config = initialize_config(model_path=self.model_path, trust_remote_code=True)
+        tokenizer = initialize_tokenizer(
+            self.model_path, self.sequence_length, self.task
+        )
+        model = initialize_model(
+            model_path=self.model_path,
+            task=self.task,
+            config=config,
+        )
+        trainer = initialize_trainer(
+            model=model, model_path=self.model_path, validation_dataset=None
+        )
+        self._test_model_device(model)
+        assert trainer.eval_dataset is None
+        assert trainer._get_fake_dataloader(num_samples=10, tokenizer=tokenizer)
+
+    def _test_model_device(self, model):
+        if model.device.type == "cuda":
+            assert self.device.startswith("cuda")
+        else:
+            assert self.device is None or self.device == "cpu"