diff --git a/src/sparseml/export/export.py b/src/sparseml/export/export.py
index 713d2ca0ef6..b016cd6f4ce 100644
--- a/src/sparseml/export/export.py
+++ b/src/sparseml/export/export.py
@@ -28,7 +28,7 @@ from sparseml.export.validators import validate_correctness as validate_correctness_
 from sparseml.export.validators import validate_structure as validate_structure_
 from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
-from sparseml.pytorch.utils.helpers import default_device
+from sparseml.pytorch.utils.helpers import default_device, use_single_gpu
 from src.sparseml.integration_helper_functions import (
     IntegrationHelperFunctions,
     resolve_integration,
@@ -117,6 +117,7 @@ def export(
     # choose the appropriate device
     device = default_device() if device == "auto" else device
+    device = use_single_gpu(device) if "cuda" in device else device
 
     # assert the valid deployment target
     if deployment_target not in AVAILABLE_DEPLOYMENT_TARGETS:
diff --git a/src/sparseml/integration_helper_functions.py b/src/sparseml/integration_helper_functions.py
index fee55b87fb6..b9301ba2417 100644
--- a/src/sparseml/integration_helper_functions.py
+++ b/src/sparseml/integration_helper_functions.py
@@ -32,10 +32,14 @@ class Integrations(Enum):
     """
 
     image_classification = "image-classification"
+    transformers = "transformers"
+    transformers_generative = "transformers-generative"
 
 
 def resolve_integration(
-    source_path: Union[Path, str], integration: Optional[str] = None
+    source_path: Union[Path, str],
+    integration: Optional[str] = None,
+    task: Optional[str] = None,
 ) -> str:
     """
     Resolve the integration to use.
@@ -47,24 +51,46 @@
     :param source_path: The path to the PyTorch model to export.
     :param integration: Optional name of the integration to use. If not provided,
         will attempt to infer it from the source_path.
+    :param task: Optional name of the task to export the model for.
     :return: The name of the integration to use for exporting the model.
     """
     if integration is not None:
         integration = integration.replace("_", "-")
+    if task is not None:
+        task = task.replace("_", "-")
+
     from sparseml.pytorch.image_classification.utils.helpers import (
         is_image_classification_model,
     )
+    from sparseml.transformers.utils.helpers import (
+        TaskNames,
+        is_transformer_generative_model,
+        is_transformer_model,
+    )
 
     if (
         integration == Integrations.image_classification.value
         or is_image_classification_model(source_path)
     ):
-        # import to register the image_classification integration helper functions
         import sparseml.pytorch.image_classification.integration_helper_functions  # noqa F401
 
         return Integrations.image_classification.value
+
+    elif task in TaskNames.text_generation.value or is_transformer_generative_model(
+        source_path
+    ):
+        import sparseml.transformers.integration_helper_functions_generative  # noqa F401
+
+        return Integrations.transformers_generative.value
+
+    elif integration == Integrations.transformers.value or is_transformer_model(
+        source_path
+    ):
+        import sparseml.transformers.integration_helper_functions  # noqa F401
+
+        return Integrations.transformers.value
     else:
         raise ValueError(
             f"Could not infer integration from source_path:\n{source_path}\n"
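For orientation, a minimal sketch of how the task-aware resolver above is expected to be called; the checkpoint directory is a placeholder, and the import path assumes the package layout used elsewhere in this PR:

    # hypothetical usage of resolve_integration() with the new task argument
    from sparseml.integration_helper_functions import resolve_integration

    integration = resolve_integration(
        source_path="./my_generative_checkpoint",  # placeholder directory
        task="text_generation",  # underscores are normalized to dashes
    )
    assert integration == "transformers-generative"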
""" - create_model: Optional[ - Callable[ - Tuple[Union[str, Path], Optional[int], str, Optional[Dict[str, Any]]], - Tuple[ - "torch.nn.Module", # noqa F821 - Optional["torch.utils.data.Dataloader"], # noqa F821 - ], - ] + create_model: Callable[ + [Union[str, Path], ...], + Tuple[ + "torch.nn.Module", # noqa F821 + Optional[Dict[str, Any]], + ], ] = Field( description="A function that takes: " "- a source path to a PyTorch model " - "- a batch size " - "- a device name " - "- (optionally) a dictionary of additional arguments" + "- (optionally) additional arguments" "and returns: " "- a (sparse) PyTorch model " - "- (optionally) a data loader " + "- (optionally) a dictionary of auxiliary items" ) - create_dummy_input: Optional[ - Callable[ - Tuple[ - Optional["torch.utils.data.Dataloader"], # noqa F821 - Optional[Dict[str, Any]], - ], - "torch.Tensor", # noqa F821 - ] - ] = Field( + create_dummy_input: Callable[..., "torch.Tensor"] = Field( # noqa F821 description="A function that takes: " - "- (optionally) a data loader " - "- (optionally) a dictionary of additional arguments" + "- appropriate arguments " "and returns: " "- a dummy input for the model (a torch.Tensor) " ) @@ -131,28 +144,29 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel): ) create_data_samples: Callable[ - Tuple[ - Optional["torch.nn.Module"], # noqa F821 - "torch.utils.data.DataLoader", # noqa F821 - int, - ], + Tuple[Optional["torch.nn.Module"], int, Optional[Dict[str, Any]]], # noqa F821 Tuple[ List["torch.Tensor"], # noqa F821 Optional[List["torch.Tensor"]], # noqa F821 - List["torch.Tensor"], # noqa F821 + Optional[List["torch.Tensor"]], # noqa F821 ], ] = Field( default=create_data_samples_, description="A function that takes: " " - (optionally) a (sparse) PyTorch model " - " - a data loader " " - the number of samples to generate " + " - (optionally) additional auxiliary items " "and returns: " - " - the inputs, labels and (optionally) outputs as torch tensors ", + " - the inputs, (optionally) labels and (optionally) outputs as torch tensors ", ) - deployment_directory_structure: List[str] = Field( + deployment_directory_files_mandatory: List[str] = Field( description="A list that describes the " - "expected files of the deployment directory", + "mandatory expected files of the deployment directory", default=["model.onnx"], ) + + deployment_directory_files_optional: Optional[List[str]] = Field( + description="A list that describes the " + "optional expected files of the deployment directory", + ) diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index f2f5eccc7d0..af23894e5dd 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -137,6 +137,15 @@ def default_device() -> str: return "cuda:{}".format(",".join(device_ids)) +def use_single_gpu(device: str) -> str: + """ + return: the first gpu in the device string if multiple are available + """ + if "cuda" not in device: + raise ValueError("use_single_gpu should only be called on cuda devices") + return device.split(",")[0] + + def device_of(inputs: Any): if isinstance(inputs, Tensor): return inputs.device @@ -538,7 +547,8 @@ def _tensors_export_batch( return if isinstance(tensors, Iterable): - for index, tens in enumerate(zip(*tensors)): + # TODO: I am breaking something here? 
 def device_of(inputs: Any):
     if isinstance(inputs, Tensor):
         return inputs.device
@@ -538,7 +547,8 @@ def _tensors_export_batch(
         return
 
     if isinstance(tensors, Iterable):
-        for index, tens in enumerate(zip(*tensors)):
+        # TODO(dbogunowicz): confirm that dropping the `*` unpacking here
+        # does not break batch export
+        for index, tens in enumerate(zip(tensors)):
             exported_paths.append(
                 tensor_export(
                     tens, export_dir, "{}-{:04d}".format(name_prefix, counter + index)
diff --git a/src/sparseml/transformers/integration_helper_functions.py b/src/sparseml/transformers/integration_helper_functions.py
new file mode 100644
index 00000000000..e32354b6b0d
--- /dev/null
+++ b/src/sparseml/transformers/integration_helper_functions.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from pydantic import Field
+from transformers import AutoTokenizer
+
+from sparseml.transformers.sparsification.trainer import Trainer
+from sparseml.transformers.utils.helpers import (
+    MANDATORY_DEPLOYMENT_FILES,
+    OPTIONAL_DEPLOYMENT_FILES,
+)
+from sparseml.transformers.utils.load_task_dataset import load_task_dataset
+from src.sparseml.export.export_data import create_data_samples as create_data_samples_
+from src.sparseml.integration_helper_functions import (
+    IntegrationHelperFunctions,
+    Integrations,
+)
+from src.sparseml.transformers.utils.initializers import (
+    _parse_data_args,
+    initialize_config,
+    initialize_model,
+    initialize_tokenizer,
+    initialize_trainer,
+    resolve_sequence_length,
+)
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def create_model(
+    source_path: Union[Path, str],
+    device: Optional[str] = None,
+    task: Optional[str] = None,
+    **kwargs,
+) -> Tuple[torch.nn.Module, Dict[str, Any]]:
+    """
+    A contract to create a model and an optional dictionary of
+    auxiliary items related to the model
+
+    :param source_path: The path to the model
+    :param device: The device to use for the model and dataloader instantiation
+    :param task: The task to use for the model and dataloader instantiation
+
+    :return: A tuple of the
+        - torch model
+        - (optionally) a dictionary of auxiliary items
+    """
+    config_args = kwargs.get("config_args", {})
+    sequence_length = kwargs.get("sequence_length", None)
+    data_args = kwargs.get("data_args", {})
+    trust_remote_code = kwargs.get("trust_remote_code", False)
+
+    if task is None:
+        raise ValueError("To create a transformer model, a task must be specified")
+
+    if not trust_remote_code:
+        _LOGGER.warning(
+            "trust_remote_code is set to False. It is possible "
+            "that the model will not be loaded correctly."
+        )
+
+    config = initialize_config(source_path, trust_remote_code, **config_args)
+    sequence_length = sequence_length or resolve_sequence_length(config)
+    tokenizer = initialize_tokenizer(source_path, sequence_length, task)
+    model = initialize_model(
+        model_path=source_path,
+        task=task,
+        config=config,
+        trust_remote_code=trust_remote_code,
+        device=device,
+    )
+
+    data_args = _parse_data_args(data_args)
+
+    if data_args:
+        dataset = load_task_dataset(
+            task=task,
+            tokenizer=tokenizer,
+            data_args=data_args,
+            model=model,
+            config=config,
+        )
+        validation_dataset = dataset.get("validation")
+
+    else:
+        validation_dataset = None
+
+    model.train()
+    trainer = initialize_trainer(model, source_path, validation_dataset)
+    model.eval()
+
+    return model, dict(
+        trainer=trainer,
+        tokenizer=tokenizer,
+        input_names=list(next(trainer._get_fake_dataloader(1, tokenizer)).keys()),
+    )
+
+
+def create_dummy_input(
+    trainer: Optional[Trainer] = None,
+    tokenizer: Optional[AutoTokenizer] = None,
+    **kwargs,
+) -> torch.Tensor:
+    if trainer is not None and trainer.eval_dataset is not None:
+        data_loader = trainer.get_eval_dataloader()
+    else:
+        if not (trainer and tokenizer):
+            raise ValueError(
+                "A trainer and a tokenizer are needed to generate "
+                "fake sample inputs when the trainer is "
+                "not initialized with an eval dataset"
+            )
+        data_loader = trainer._get_fake_dataloader(num_samples=1, tokenizer=tokenizer)
+    return next(iter(data_loader))
+
+
+def create_data_samples(
+    num_samples: int,
+    trainer: Trainer,
+    model: Optional["torch.nn.Module"] = None,
+    **kwargs,
+):
+    if kwargs.get("batch_size"):
+        _LOGGER.info(
+            "For exporting samples for the transformers integration, "
+            "batch size is ignored (equal to 1)"
+        )
+    if trainer.eval_dataset is None:
+        raise ValueError(
+            "Attempting to create data samples without an eval dataloader. "
+            "Initialize a trainer with an eval dataset"
+        )
+
+    return create_data_samples_(
+        data_loader=trainer.get_eval_dataloader(), model=model, num_samples=num_samples
+    )
+
+
+@IntegrationHelperFunctions.register(name=Integrations.transformers.value)
+class Transformers(IntegrationHelperFunctions):
+    create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
+        default=create_model
+    )
+    create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
+    create_data_samples: Callable = Field(default=create_data_samples)
+    deployment_directory_files_mandatory: List[str] = Field(
+        default=list(MANDATORY_DEPLOYMENT_FILES)
+    )
+    deployment_directory_files_optional: List[str] = Field(
+        default=list(OPTIONAL_DEPLOYMENT_FILES)
+    )
diff --git a/src/sparseml/transformers/integration_helper_functions_generative.py b/src/sparseml/transformers/integration_helper_functions_generative.py
new file mode 100644
index 00000000000..2c17f47395a
--- /dev/null
+++ b/src/sparseml/transformers/integration_helper_functions_generative.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
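For orientation, a hedged sketch of how the registered transformers helpers above are meant to be consumed, mirroring what export() does internally; the checkpoint path and task are placeholders:

    # hypothetical flow through the registry defined in this PR
    from sparseml.integration_helper_functions import IntegrationHelperFunctions

    helpers = IntegrationHelperFunctions.load_from_registry("transformers")
    model, aux = helpers.create_model("./qa_checkpoint", task="qa", device="cpu")
    dummy_input = helpers.create_dummy_input(
        trainer=aux["trainer"], tokenizer=aux["tokenizer"]
    )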
+ + +from typing import Callable, Dict, List + +from pydantic import Field + +from sparseml.transformers.integration_helper_functions import Transformers +from sparseml.transformers.utils.helpers import ( + MANDATORY_DEPLOYMENT_FILES, + NLG_TOKENIZER_FILES, +) +from sparseml.transformers.utils.optimizations import apply_kv_cache_injection +from src.sparseml.integration_helper_functions import ( + IntegrationHelperFunctions, + Integrations, +) + + +generative_transformers_graph_optimizations = { + "kv_cache_injection": apply_kv_cache_injection +} + + +@IntegrationHelperFunctions.register(name=Integrations.transformers_generative.value) +class GenerativeTransformers(Transformers): + graph_optimizations: Dict[str, Callable] = Field( + default=generative_transformers_graph_optimizations + ) + deployment_directory_files_mandatory: List[str] = Field( + default=list(MANDATORY_DEPLOYMENT_FILES.union(NLG_TOKENIZER_FILES)) + ) diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index f28f04cc560..96e2767e7fc 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -22,12 +22,10 @@ import math import os import warnings -from contextlib import suppress from dataclasses import asdict from typing import Any, Dict, List, Optional, Tuple, Union import datasets -import numpy import torch from torch import distributed as dist from torch.nn import Module @@ -499,90 +497,6 @@ def log_model_sparsification(self): f"all sparsification info: {sparsification_info}" ) - def save_sample_inputs_outputs( - self, - num_samples_to_export: int = 100, - output_dir: Optional[str] = None, - tokenizer: Optional[Any] = None, - ): - """ - Save sample inputs/outputs/labels in save_dir as .npz arrays - - :param num_samples_to_export: Number of samples to export. 
- Defaults to 100 - :param output_dir: The directory to store sample inputs and outputs in - :param tokenizer: if eval and train dataset cannot be generated, then - the tokenizer is used to generate fake inputs - """ - num_samples = 0 - - if output_dir is None: - output_dir = ( - self.args.output_dir if hasattr(self.args, "output_dir") else "" - ) - - sample_in_dir = os.path.join(output_dir, "sample-inputs") - sample_out_dir = os.path.join(output_dir, "sample-outputs") - - os.makedirs(sample_in_dir, exist_ok=True) - os.makedirs(sample_out_dir, exist_ok=True) - device = self.model.device - - dataloader = None - try: - dataloader = self.get_eval_dataloader() - except Exception: - with suppress(ValueError): - dataloader = self.get_train_dataloader() - - if not dataloader and not tokenizer: - raise ValueError( - "tokenizer is needed to generate fake sample inputs when Trainer is " - "not initialized with a train or eval dataset" - ) - if dataloader is None: - # we have the tokenizer so use it - dataloader = self._get_fake_dataloader( - num_samples=num_samples_to_export, tokenizer=tokenizer - ) - - _LOGGER.info( - f"Exporting {num_samples_to_export} samples to " - f"{os.path.abspath(output_dir)}" - ) - for _, sample_batch in enumerate(dataloader): - sample_batch.pop("labels", None) - input_names = list(sample_batch.keys()) - - for input_vals in zip(*sample_batch.values()): - input_feed = {k: v.to("cpu") for k, v in zip(input_names, input_vals)} - model_inputs = { - k: input_feed[k].to(device).reshape(1, -1) for k in input_feed - } - output_vals = self.model(**model_inputs) - output_dict = { - name: torch.squeeze(val).detach().to("cpu") - for name, val in output_vals.items() - } - file_idx = f"{num_samples}".zfill(4) - - sample_input_filename = os.path.join( - f"{sample_in_dir}", f"inp-{file_idx}.npz" - ) - numpy.savez(sample_input_filename, **input_feed) - - sample_output_filename = os.path.join( - f"{sample_out_dir}", f"out-{file_idx}.npz" - ) - numpy.savez(sample_output_filename, **output_dict) - num_samples += 1 - - if num_samples >= num_samples_to_export: - break - if num_samples >= num_samples_to_export: - break - _LOGGER.info(f"Exported {num_samples_to_export} samples to {output_dir}") - def _extract_metadata( self, metadata_args: List[str], @@ -789,7 +703,7 @@ def _get_fake_dataloader( num_samples: int, tokenizer: "PreTrainedTokenizerBase", # noqa: F821 ): - # Rearrange inputs' keys to match those defined by model foward func, which + # Rearrange inputs' keys to match those defined by model forward func, which # seem to define how the order of inputs is determined in the exported model forward_args_spec = inspect.getfullargspec(self.model.__class__.forward) synthetic_input = self._get_fake_input( diff --git a/src/sparseml/transformers/utils/__init__.py b/src/sparseml/transformers/utils/__init__.py index 4647a435f12..16bfb5062e4 100644 --- a/src/sparseml/transformers/utils/__init__.py +++ b/src/sparseml/transformers/utils/__init__.py @@ -17,7 +17,6 @@ """ # flake8: noqa - from .helpers import * from .metrics import * from .model import * diff --git a/src/sparseml/transformers/utils/helpers.py b/src/sparseml/transformers/utils/helpers.py index c796f5a39ac..b80d592ba8d 100644 --- a/src/sparseml/transformers/utils/helpers.py +++ b/src/sparseml/transformers/utils/helpers.py @@ -16,22 +16,109 @@ Helper variables and functions for integrating SparseML with huggingface/transformers flows """ + import logging import os -from typing import Optional +from enum import Enum +from pathlib import Path 
+from typing import Any, Dict, Optional, Tuple, Union
 
+import torch
 from transformers.trainer_utils import get_last_checkpoint
 
+from sparseml.export.helpers import ONNX_MODEL_NAME
 from sparsezoo import setup_model
 
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-__all__ = ["RECIPE_NAME", "save_zoo_directory", "detect_last_checkpoint"]
+__all__ = [
+    "RECIPE_NAME",
+    "save_zoo_directory",
+    "detect_last_checkpoint",
+    "TaskNames",
+    "is_transformer_model",
+    "is_transformer_generative_model",
+    "run_transformers_inference",
+]
+
+
+class TaskNames(Enum):
+    mlm = {"masked-language-modeling", "mlm"}
+    qa = {"question-answering", "qa"}
+    token_classification = {"token-classification", "ner"}
+    text_classification = {
+        "text-classification",
+        "sentiment-analysis",
+        "sequence-classification",
+        "glue",
+    }
+    text_generation = {"text-generation"}
 
 
 RECIPE_NAME = "recipe.yaml"
+MANDATORY_DEPLOYMENT_FILES = {
+    ONNX_MODEL_NAME,
+    "tokenizer_config.json",
+    "config.json",
+}
+NLG_TOKENIZER_FILES = {"special_tokens_map.json", "vocab.json", "merges.txt"}
+OPTIONAL_DEPLOYMENT_FILES = {"tokenizer.json", "tokenizer.model"}
+
+
+# TODO: Move this functionality to the export module once merged
+def run_transformers_inference(
+    inputs: Dict[str, Any], model: Optional[torch.nn.Module] = None
+) -> Tuple[Dict[str, Any], Any, Dict[str, Any]]:
+    """
+    Run inference on a transformers model and return the inputs, labels and outputs
+
+    :param inputs: The inputs to run inference on
+    :param model: The model to run inference on (optional)
+
+    :return: The inputs, labels and outputs
+    """
+    label = None  # transformers in general have no labels
+    if model is None:
+        inputs = {key: value.to("cpu") for key, value in inputs.items()}
+        return inputs, label, None
+
+    inputs = {key: value.to(model.device) for key, value in inputs.items()}
+    output_vals = model(**inputs)
+    inputs = {key: value.to("cpu") for key, value in inputs.items()}
+    output = {
+        name: torch.squeeze(val).detach().to("cpu")
+        for name, val in output_vals.items()
+    }
+    return inputs, label, output
+
+
+def is_transformer_model(source_path: Union[Path, str]) -> bool:
+    """
+    :param source_path: The path to the model
+    :return: Whether the model is a transformers model or not
+    """
+    # make sure that the path is a directory and contains
+    # all mandatory deployment files except the ONNX model
+    if not os.path.isdir(source_path):
+        raise ValueError(f"Path {source_path} is not a valid directory")
+    expected_files = MANDATORY_DEPLOYMENT_FILES.difference({ONNX_MODEL_NAME})
+    return expected_files.issubset(os.listdir(source_path))
+
+
+def is_transformer_generative_model(source_path: Union[Path, str]) -> bool:
+    """
+    :param source_path: The path to the model
+    :return: Whether the model is a generative transformers model or not
+    """
+    # make sure that the path is a directory and contains the mandatory
+    # deployment and NLG tokenizer files (ONNX model excluded)
+    if not os.path.isdir(source_path):
+        raise ValueError(f"Path {source_path} is not a valid directory")
+    expected_files = MANDATORY_DEPLOYMENT_FILES.union(NLG_TOKENIZER_FILES).difference(
+        {ONNX_MODEL_NAME}
+    )
+    return expected_files.issubset(os.listdir(source_path))
 
 
 def save_zoo_directory(
diff --git a/src/sparseml/transformers/utils/initializers.py b/src/sparseml/transformers/utils/initializers.py
new file mode 100644
index 00000000000..08e69955877
--- /dev/null
+++ b/src/sparseml/transformers/utils/initializers.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Functionality for initializing a transformer model from a given path +""" + +import logging +import math +import os +from pathlib import Path +from typing import Any, Optional, Union + +from transformers import AutoConfig, AutoModel, AutoTokenizer, TrainingArguments + +from sparseml.optim import parse_recipe_variables +from sparseml.transformers.sparsification import Trainer +from sparseml.transformers.utils.helpers import TaskNames +from sparseml.transformers.utils.load_task_model import load_task_model + + +__all__ = [ + "initialize_model", + "initialize_tokenizer", + "initialize_trainer", + "initialize_config", + "resolve_sequence_length", +] + +_LOGGER = logging.getLogger(__name__) + + +def initialize_config( + model_path: Union[str, Path], trust_remote_code: bool = False, **config_args +) -> AutoConfig: + """ + Initialize a config from a given path + + :param model_path: the path to the model to load + :param trust_remote_code: True to trust remote code when loading the model, + False otherwise + :param config_args: additional arguments to pass to the config + :return: the loaded config + """ + config = AutoConfig.from_pretrained( + model_path, + trust_remote_code=trust_remote_code, + **config_args, + ) + return config + + +def initialize_tokenizer( + model_path: Union[str, Path], sequence_length: int, task: str +) -> AutoTokenizer: + """ + Initialize a tokenizer from a given path + + :param model_path: the path to the model to load + :param sequence_length: the sequence length to use for the tokenizer + :param task: the task to load the tokenizer for + :return: the loaded tokenizer + """ + + tokenizer = AutoTokenizer.from_pretrained( + model_path, model_max_length=sequence_length + ) + if task in TaskNames.text_generation.value: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +def initialize_model( + model_path: Union[str, Path], + task: str, + config: AutoConfig, + trust_remote_code: bool = False, + device: Optional[str] = None, +) -> AutoModel: + """ + Initialize a model from a given path + + :param model_path: the path to the model to load + :param task: the task to load the model for + :param config: the config to use for the model + :param trust_remote_code: True to trust remote code when loading the model, + False otherwise + :param device: the device to load the model on. If None, will load on CPU + :return: the loaded model + """ + model = load_task_model( + task=task, + model_path=model_path, + config=config, + trust_remote_code=trust_remote_code, + ) + if device: + model.to(device) + return model + + +def initialize_trainer( + model: AutoModel, + model_path: Union[str, Path], + validation_dataset: Optional[Any] = None, +) -> Trainer: + """ + Initialize a trainer. 
This will apply the structure dictated by + any of the recipes stored in the model_path + + :param model: the model to initialize the trainer with + :param model_path: the path to the model to load + :param validation_dataset: the validation dataset to use for the trainer + :return: the initialized trainer + """ + + training_args = TrainingArguments(output_dir=os.path.dirname(model_path)) + + trainer = Trainer( + model=model, + args=training_args, + model_state_path=model_path, + eval_dataset=validation_dataset, + recipe=None, + recipe_args=None, + teacher=None, + ) + applied = trainer.apply_manager(epoch=math.inf, checkpoint=None) + + if not applied: + _LOGGER.warning( + f"No recipes were applied for {model_path}, " + "check to make sure recipe(s) are stored in the model_path" + ) + else: + trainer.finalize_manager() + num_stages = 0 + if trainer.manager: + num_stages += trainer.manager.num_stages() + if trainer.arch_manager: + num_stages += trainer.arch_manager.num_stages() + + msg = ( + "an unstaged recipe" + if num_stages == 1 + else f"a staged recipe with {num_stages} stages" + ) + _LOGGER.info(f"Applied {msg} to the model at {model_path}") + + return trainer + + +def resolve_sequence_length(config: AutoConfig) -> int: + """ + Resolve the sequence length from the config + + :param config: the config to resolve the sequence length from + :return: the sequence length + """ + if hasattr(config, "max_position_embeddings"): + sequence_length = config.max_position_embeddings + + elif hasattr(config, "max_seq_len"): + sequence_length = config.max_seq_len + else: + raise ValueError( + "Could not infer a default sequence length " + "from the HF transformers config. Please specify " + "the sequence length with --sequence_length" + ) + _LOGGER.info( + f"Using default sequence length of {sequence_length} " + "(inferred from HF transformers config) " + ) + return sequence_length + + +def _parse_data_args(data_args): + try: + return parse_recipe_variables(data_args) + except ValueError as parse_error: + message = str(parse_error).replace("recipe_args", "data_args") + if "recipe variables" in message: + message = message.replace("recipe variables", "data_args") + raise ValueError(message) diff --git a/src/sparseml/transformers/utils/load_task_dataset.py b/src/sparseml/transformers/utils/load_task_dataset.py new file mode 100644 index 00000000000..cb718b02696 --- /dev/null +++ b/src/sparseml/transformers/utils/load_task_dataset.py @@ -0,0 +1,94 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +from torch.nn import Module +from transformers import AutoConfig, AutoTokenizer + +from sparseml.transformers.utils.helpers import TaskNames + + +__all__ = ["load_task_dataset"] + + +def load_task_dataset( + task: str, + tokenizer: AutoTokenizer, + data_args: Dict[str, Any], + model: Module, + config: Optional[AutoConfig] = None, +) -> Any: + """ + + Load a dataset for a given task. 
+ + :param task: the task a dataset being loaded for + :param tokenizer: the tokenizer to use for the dataset + :param data_args: additional data args used to create a `DataTrainingArguments` + instance for fetching the dataset + :param model: the model to use for the dataset + :param config: the config to use for the dataset + :return: the dataset for the given task + """ + + if task in TaskNames.mlm.value: + from sparseml.transformers.masked_language_modeling import ( + DataTrainingArguments, + get_tokenized_mlm_dataset, + ) + + data_training_args = DataTrainingArguments(**data_args) + return get_tokenized_mlm_dataset( + data_args=data_training_args, tokenizer=tokenizer + ) + + if task in TaskNames.qa.value: + from sparseml.transformers.question_answering import ( + DataTrainingArguments, + get_tokenized_qa_dataset, + ) + + data_training_args = DataTrainingArguments(**data_args) + return get_tokenized_qa_dataset( + data_args=data_training_args, tokenizer=tokenizer + ) + + if task in TaskNames.token_classification.value: + from sparseml.transformers.token_classification import ( + DataTrainingArguments, + get_tokenized_token_classification_dataset, + ) + + data_training_args = DataTrainingArguments(**data_args) + return get_tokenized_token_classification_dataset( + data_args=data_training_args, tokenizer=tokenizer, model=model or config + ) + + if task in TaskNames.text_classification.value: + from sparseml.transformers.text_classification import ( + DataTrainingArguments, + get_tokenized_text_classification_dataset, + ) + + data_training_args = DataTrainingArguments(**data_args) + + return get_tokenized_text_classification_dataset( + data_args=data_training_args, + tokenizer=tokenizer, + model=model, + config=config, + ) + + raise ValueError(f"unrecognized task given of {task}") diff --git a/src/sparseml/transformers/utils/load_task_model.py b/src/sparseml/transformers/utils/load_task_model.py new file mode 100644 index 00000000000..cb417a51960 --- /dev/null +++ b/src/sparseml/transformers/utils/load_task_model.py @@ -0,0 +1,69 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any + +from torch.nn import Module + +from sparseml.transformers.utils.helpers import TaskNames +from sparseml.transformers.utils.model import SparseAutoModel + + +__all__ = ["load_task_model"] + + +def load_task_model( + task: str, model_path: str, config: Any, trust_remote_code: bool = False +) -> Module: + if task in TaskNames.mlm.value: + return SparseAutoModel.masked_language_modeling_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + trust_remote_code=trust_remote_code, + ) + + if task in TaskNames.qa.value: + return SparseAutoModel.question_answering_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + trust_remote_code=trust_remote_code, + ) + + if task in TaskNames.text_classification.value: + return SparseAutoModel.text_classification_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + trust_remote_code=trust_remote_code, + ) + + if task in TaskNames.token_classification.value: + return SparseAutoModel.token_classification_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + trust_remote_code=trust_remote_code, + ) + + if task in TaskNames.text_generation.value: + return SparseAutoModel.text_generation_from_pretrained( + model_name_or_path=model_path, + config=config, + model_type="model", + trust_remote_code=trust_remote_code, + ) + + raise ValueError(f"unrecognized task given of {task}") diff --git a/src/sparseml/transformers/utils/optimizations.py b/src/sparseml/transformers/utils/optimizations.py new file mode 100644 index 00000000000..20b58775b10 --- /dev/null +++ b/src/sparseml/transformers/utils/optimizations.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from pathlib import Path +from typing import Union + +import onnx + +from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector + + +__all__ = ["apply_kv_cache_injection"] + +_LOGGER = logging.getLogger(__name__) + + +def apply_kv_cache_injection(onnx_model_path: Union[str, Path]) -> bool: + """ + Apply key value cache injection to an ONNX model + + :param onnx_model_path: path to the ONNX model to inject + :return: True if successful, False otherwise + """ + onnx_model = onnx.load(onnx_model_path, load_external_data=False) + model_path = os.path.dirname(onnx_model_path) + exporter = KeyValueCacheInjector(model_path=model_path) + exporter.export(onnx_model, onnx_model_path) + return True diff --git a/tests/sparseml/export/transformers/__init__.py b/tests/sparseml/export/transformers/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/export/transformers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/sparseml/export/transformers/test_generative_transformers.py b/tests/sparseml/export/transformers/test_generative_transformers.py
new file mode 100644
index 00000000000..8a014360fee
--- /dev/null
+++ b/tests/sparseml/export/transformers/test_generative_transformers.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+
+import pytest
+import torch
+
+from huggingface_hub import snapshot_download
+from sparseml.export.export import export
+
+
+@pytest.mark.parametrize(
+    "stub, task",
+    [("roneneldan/TinyStories-1M", "text-generation")],
+)
+class TestEndToEndExport:
+    @pytest.fixture()
+    def setup(self, tmp_path, stub, task):
+        model_path = tmp_path / "model"
+        target_path = tmp_path / "target"
+
+        source_path = snapshot_download(stub, local_dir=model_path)
+
+        yield source_path, target_path, task
+
+        shutil.rmtree(tmp_path)
+
+    def test_export_happy_path(self, setup):
+        source_path, target_path, task = setup
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+
+    def test_export_with_sample_data(self, setup):
+        source_path, target_path, task = setup
+
+        sequence_length = 32
+        sample_data = dict(
+            input_ids=torch.ones((10, sequence_length), dtype=torch.long),
+            attention_mask=torch.ones((10, sequence_length), dtype=torch.long),
+        )
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            sample_data=sample_data,
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+
+    @pytest.mark.skip(
+        reason="skipping since this functionality needs some more attention"
+    )
+    def test_export_validate_correctness(self, setup):
+        source_path, target_path, task = setup
+
+        num_samples = 4
+
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            num_export_samples=num_samples,
+            validate_correctness=True,
+        )
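For reference, a sketch of the deployment directory the generative export is expected to produce, assembled from the MANDATORY_DEPLOYMENT_FILES and NLG_TOKENIZER_FILES sets defined in this PR (a sketch, not verified output):

    # target/
    # └── deployment/
    #     ├── model.onnx               # MANDATORY_DEPLOYMENT_FILES
    #     ├── config.json
    #     ├── tokenizer_config.json
    #     ├── special_tokens_map.json  # NLG_TOKENIZER_FILES
    #     ├── vocab.json
    #     └── merges.txt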
diff --git a/tests/sparseml/export/transformers/test_transformers.py b/tests/sparseml/export/transformers/test_transformers.py
new file mode 100644
index 00000000000..b2959ed1ae4
--- /dev/null
+++ b/tests/sparseml/export/transformers/test_transformers.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import glob
+import os
+import shutil
+
+import numpy as np
+import pytest
+import torch
+
+from sparseml.export.export import export
+from sparsezoo import Model
+
+
+@pytest.mark.parametrize(
+    "stub, task",
+    [
+        ("zoo:obert-medium-squad_wikipedia_bookcorpus-pruned95_quantized", "qa"),
+    ],
+)
+class TestEndToEndExport:
+    @pytest.fixture()
+    def setup(self, tmp_path, stub, task):
+        model_path = tmp_path / "model"
+        target_path = tmp_path / "target"
+
+        source_path = Model(stub, model_path).training.path
+
+        yield source_path, target_path, task
+
+        shutil.rmtree(tmp_path)
+
+    def test_export_happy_path(self, setup):
+        source_path, target_path, task = setup
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+
+    def test_export_samples(self, setup):
+        source_path, target_path, task = setup
+
+        num_samples = 4
+
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            num_export_samples=num_samples,
+            **dict(data_args=dict(dataset_name="squad")),
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+        assert (
+            len(os.listdir(os.path.join(target_path, "sample-inputs"))) == num_samples
+        )
+        assert (
+            len(os.listdir(os.path.join(target_path, "sample-outputs"))) == num_samples
+        )
+        assert np.load(
+            glob.glob(os.path.join(target_path, "sample-inputs/*"))[0],
+            allow_pickle=True,
+        )["arr_0"]
+
+    def test_export_with_sample_data(self, setup):
+        source_path, target_path, task = setup
+
+        sequence_length = 32
+        sample_data = dict(
+            input_ids=torch.ones((10, sequence_length), dtype=torch.long),
+            attention_mask=torch.ones((10, sequence_length), dtype=torch.long),
+        )
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            sample_data=sample_data,
+        )
+        assert (target_path / "deployment" / "model.onnx").exists()
+
+    @pytest.mark.skip(reason="skipping since not implemented")
+    def test_export_multiple_files(self, setup):
+        source_path, target_path, task = setup
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            single_graph_file=False,
+        )
+
+    @pytest.mark.skip(
+        reason="skipping since this functionality needs some more attention"
+    )
+    def test_export_validate_correctness(self, setup):
+        source_path, target_path, task = setup
+
+        num_samples = 4
+
+        export(
+            source_path=source_path,
+            target_path=target_path,
+            task=task,
+            num_export_samples=num_samples,
+            validate_correctness=True,
+            **dict(data_args=dict(dataset_name="squad")),
+        )
diff --git a/tests/sparseml/transformers/test_integration_helper_functions.py b/tests/sparseml/transformers/test_integration_helper_functions.py
new file mode 100644
index 00000000000..13e60a040a1
--- /dev/null
+++ b/tests/sparseml/transformers/test_integration_helper_functions.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from src.sparseml.integration_helper_functions import ( + IntegrationHelperFunctions, + Integrations, +) + + +def test_integration_helper_functions(): + # import needed to register the object on the fly + import sparseml.transformers.integration_helper_functions # noqa F401 + + transformers = IntegrationHelperFunctions.load_from_registry( + Integrations.transformers.value + ) + assert transformers.create_model + assert transformers.create_dummy_input + assert transformers.export + assert transformers.graph_optimizations is None + assert transformers.create_data_samples + assert set(transformers.deployment_directory_files_mandatory) == { + "model.onnx", + "tokenizer_config.json", + "config.json", + } + assert set(transformers.deployment_directory_files_optional) == { + "tokenizer.json", + "tokenizer.model", + } diff --git a/tests/sparseml/transformers/test_integration_helper_functions_generative.py b/tests/sparseml/transformers/test_integration_helper_functions_generative.py new file mode 100644 index 00000000000..7a1c77f7066 --- /dev/null +++ b/tests/sparseml/transformers/test_integration_helper_functions_generative.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from src.sparseml.integration_helper_functions import (
+    IntegrationHelperFunctions,
+    Integrations,
+)
+
+
+def test_integration_helper_functions():
+    # import needed to register the object on the fly
+    import sparseml.transformers.integration_helper_functions_generative  # noqa F401
+
+    transformers_gen = IntegrationHelperFunctions.load_from_registry(
+        Integrations.transformers_generative.value
+    )
+    assert transformers_gen.create_model
+    assert transformers_gen.create_dummy_input
+    assert transformers_gen.export
+    assert list(transformers_gen.graph_optimizations.keys()) == ["kv_cache_injection"]
+    assert transformers_gen.create_data_samples
+    assert set(transformers_gen.deployment_directory_files_mandatory) == {
+        "model.onnx",
+        "tokenizer_config.json",
+        "config.json",
+        "special_tokens_map.json",
+        "vocab.json",
+        "merges.txt",
+    }
+    assert set(transformers_gen.deployment_directory_files_optional) == {
+        "tokenizer.json",
+        "tokenizer.model",
+    }
diff --git a/tests/sparseml/transformers/utils/test_helpers.py b/tests/sparseml/transformers/utils/test_helpers.py
index e2413f492ae..b657125c09d 100644
--- a/tests/sparseml/transformers/utils/test_helpers.py
+++ b/tests/sparseml/transformers/utils/test_helpers.py
@@ -12,12 +12,110 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import OrderedDict
+
 import pytest
+import torch
 
-from sparseml.transformers.utils.helpers import save_zoo_directory
+from huggingface_hub import snapshot_download
+from sparseml.transformers.utils.helpers import (
+    is_transformer_generative_model,
+    is_transformer_model,
+    run_transformers_inference,
+    save_zoo_directory,
+)
+from sparseml.transformers.utils.initializers import initialize_config, initialize_model
 from sparsezoo import Model
 
 
+@pytest.fixture()
+def generative_model_path(tmp_path):
+    return snapshot_download("roneneldan/TinyStories-1M", local_dir=tmp_path)
+
+
+@pytest.fixture()
+def model_path(tmp_path):
+    return Model(
+        "zoo:mobilebert-squad_wikipedia_bookcorpus-14layer_pruned50.4block_quantized",
+        tmp_path,
+    ).training.path
+
+
+@pytest.fixture()
+def sequence_length():
+    return 384
+
+
+@pytest.fixture()
+def dummy_inputs():
+    input_ids = torch.zeros((1, 32), dtype=torch.int64)
+    attention_mask = torch.ones((1, 32), dtype=torch.int64)
+
+    return OrderedDict(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+    )
+
+
+@pytest.mark.parametrize(
+    "stub",
+    [
+        "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned95_obs_quant-none",  # noqa E501
+    ],
+)
+def test_is_transformer_model(tmp_path, stub):
+    zoo_model = Model(stub, tmp_path)
+    source_path = zoo_model.training.path
+    assert is_transformer_model(source_path)
+
+
+def test_is_transformer_generative_model(generative_model_path):
+    assert is_transformer_generative_model(generative_model_path)
+
+
+def test_run_transformers_inference_generative(generative_model_path, dummy_inputs):
+    config = initialize_config(
+        model_path=generative_model_path,
+        trust_remote_code=True,
+        **dict(use_cache=False),
+    )
+    model = initialize_model(
+        model_path=generative_model_path,
+        task="text-generation",
+        config=config,
+    )
+
+    inputs, label, output = run_transformers_inference(inputs=dummy_inputs, model=None)
+    assert isinstance(inputs, dict)
+    assert label is None
+    assert output is None
+
+    inputs, label, output = run_transformers_inference(inputs=dummy_inputs, model=model)
+    assert isinstance(inputs, dict)
+    assert label is None
+    assert isinstance(output, dict)
+
+
+def test_run_transformers_inference(model_path, dummy_inputs):
+    config = initialize_config(model_path=model_path, trust_remote_code=True)
+    model = initialize_model(
+        model_path=model_path,
+        task="qa",
+        config=config,
+    )
+
+    inputs, label, output = run_transformers_inference(inputs=dummy_inputs, model=None)
+    assert isinstance(inputs, dict)
+    assert label is None
+    assert output is None
+
+    inputs, label, output = run_transformers_inference(inputs=dummy_inputs, model=model)
+    assert isinstance(inputs, dict)
+    assert label is None
+    assert isinstance(output, dict)
+
+
 @pytest.mark.parametrize(
     "stub",
     [
diff --git a/tests/sparseml/transformers/utils/test_initializers.py b/tests/sparseml/transformers/utils/test_initializers.py
new file mode 100644
index 00000000000..cf4bb9641a4
--- /dev/null
+++ b/tests/sparseml/transformers/utils/test_initializers.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+
+import pytest
+
+from sparseml.pytorch.utils.helpers import default_device, use_single_gpu
+from sparseml.transformers.utils.load_task_dataset import load_task_dataset
+from sparsezoo import Model
+from src.sparseml.transformers.utils.initializers import (
+    initialize_config,
+    initialize_model,
+    initialize_tokenizer,
+    initialize_trainer,
+)
+
+
+@pytest.mark.parametrize(
+    "stub, task, data_args",
+    [
+        (
+            "zoo:obert-medium-squad_wikipedia_bookcorpus-pruned95_quantized",
+            "qa",
+            dict(dataset_name="squad"),
+        ),
+        (
+            "zoo:distilbert-qqp_wikipedia_bookcorpus-pruned80.4block_quantized",
+            "text-classification",
+            None,
+        ),
+    ],
+    scope="class",
+)
+@pytest.mark.parametrize("device", ["auto", "cpu", None], scope="class")
+class TestInitializeModelFlow:
+    @pytest.fixture()
+    def setup(self, tmp_path, stub, task, data_args, device):
+        self.model_path = Model(stub, tmp_path).training.path
+        self.sequence_length = 384
+        self.task = task
+        self.data_args = data_args
+
+        # process device argument
+        device = default_device() if device == "auto" else device
+        # if multiple gpus available use the first one
+        if not (device is None or device == "cpu"):
+            device = use_single_gpu(device)
+        self.device = device
+        yield
+        shutil.rmtree(tmp_path)
+
+    def test_initialize_config(self, setup):
+        assert initialize_config(model_path=self.model_path, trust_remote_code=True)
+
+    def test_initialize_tokenizer(self, setup):
+        tokenizer = initialize_tokenizer(
+            self.model_path, self.sequence_length, self.task
+        )
+        expected_padding_side = (
+            "right" if self.task != "text-generation" else "left"
+        )
+        assert tokenizer.padding_side == expected_padding_side
+        assert tokenizer.model_max_length == self.sequence_length
+
+    def test_initialize_model(self, setup):
+        model = initialize_model(
+            model_path=self.model_path,
+            task=self.task,
+            device=self.device,
+            config=initialize_config(
+                model_path=self.model_path, trust_remote_code=True
+            ),
+        )
+        assert model
+        self._test_model_device(model)
+    def test_initialize_trainer(self, setup):
+        if not self.data_args:
+            pytest.skip("To run this test, please provide valid data_args")
+        config = initialize_config(model_path=self.model_path, trust_remote_code=True)
+        model = initialize_model(
+            model_path=self.model_path,
+            task=self.task,
+            device=self.device,
+            config=config,
+        )
+        tokenizer = initialize_tokenizer(
+            self.model_path, self.sequence_length, self.task
+        )
+        dataset = load_task_dataset(
+            task=self.task,
+            tokenizer=tokenizer,
+            data_args=self.data_args,
+            model=model,
+            config=config,
+        )
+        validation_dataset = dataset.get("validation")
+
+        trainer = initialize_trainer(
+            model=model,
+            model_path=self.model_path,
+            validation_dataset=validation_dataset,
+        )
+        # assert that the trainer does not alter the model's device placement
+        self._test_model_device(model)
+        assert trainer.get_eval_dataloader()
+
+    def test_initialize_trainer_no_validation_dataset(self, setup):
+        config = initialize_config(model_path=self.model_path, trust_remote_code=True)
+        tokenizer = initialize_tokenizer(
+            self.model_path, self.sequence_length, self.task
+        )
+        model = initialize_model(
+            model_path=self.model_path,
+            task=self.task,
+            config=config,
+        )
+        trainer = initialize_trainer(
+            model=model, model_path=self.model_path, validation_dataset=None
+        )
+        assert trainer.eval_dataset is None
+        assert trainer._get_fake_dataloader(num_samples=10, tokenizer=tokenizer)
+
+    def _test_model_device(self, model):
+        if model.device.type == "cuda":
+            assert self.device.startswith("cuda")
+        else:
+            assert self.device is None or self.device == "cpu"
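Taken together, a hedged end-to-end sketch of the new export flow for a local transformers checkpoint; paths are placeholders and data_args mirrors the QA tests above:

    from sparseml.export.export import export

    export(
        source_path="./qa_checkpoint",  # placeholder checkpoint directory
        target_path="./exported",
        task="qa",  # resolves the "transformers" integration
        num_export_samples=4,  # writes sample-inputs/ and sample-outputs/
        **dict(data_args=dict(dataset_name="squad")),
    )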