diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py
index b36ef5ee85d..fa94e6be94f 100644
--- a/src/sparseml/transformers/export.py
+++ b/src/sparseml/transformers/export.py
@@ -75,10 +75,12 @@
 import math
 import os
 import shutil
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
 from torch.nn import Module
 from transformers import AutoConfig, AutoTokenizer
+from transformers import TrainingArguments as HFTrainingArgs
 from transformers.tokenization_utils_base import PaddingStrategy
 
 from sparseml.optim import parse_recipe_variables
@@ -107,6 +109,14 @@
 _LOGGER = logging.getLogger(__name__)
 
 
+@dataclass
+class DeviceCPUTrainingArgs(HFTrainingArgs):
+    @property
+    def place_model_on_device(self):
+        # Ensure model remains in CPU during ONNX export
+        return False
+
+
 def load_task_model(task: str, model_path: str, config: Any) -> Module:
     if task == "masked-language-modeling" or task == "mlm":
         return SparseAutoModel.masked_language_modeling_from_pretrained(
@@ -294,15 +304,18 @@ def export_transformer_to_onnx(
     _LOGGER.info(f"loaded validation dataset for args {data_args}")
 
     model = model.train()
+
+    args = DeviceCPUTrainingArgs(output_dir="tmp_trainer")
     trainer = Trainer(
         model=model,
+        args=args,
         model_state_path=model_path,
         eval_dataset=eval_dataset,
         recipe=None,
         recipe_args=None,
         teacher=None,
     )
-    model = model.cpu()
+
     applied = trainer.apply_manager(epoch=math.inf, checkpoint=None)
 
     if not applied:
diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index 9ea0ce8bfc0..566966de772 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -125,7 +125,7 @@ def __init__(
                 training_args_dict=training_args.to_dict(),
                 data_args_dict=asdict(data_args) if data_args else {},
             )
-            if training_args
+            if training_args and metadata_args
             else None
         )
 
@@ -762,7 +762,6 @@ def _get_fake_dataloader(
         num_samples: int,
         tokenizer: "PreTrainedTokenizerBase",  # noqa: F821
     ):
-
         # Rearrange inputs' keys to match those defined by model foward func, which
         # seem to define how the order of inputs is determined in the exported model
         forward_args_spec = inspect.getfullargspec(self.model.__class__.forward)
@@ -820,7 +819,6 @@ def __init__(
         teacher: Optional[Union[Module, str]] = None,
         **kwargs,
     ):
-
         super().__init__(
             model=model,
             model_state_path=model_state_path,
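
Note on the export.py change: the Hugging Face Trainer consults args.place_model_on_device at construction time and only moves the model to args.device when that property is truthy, so overriding it to return False keeps the model on CPU during ONNX export and makes the explicit model = model.cpu() call unnecessary. Below is a minimal standalone sketch of the same pattern against the stock transformers Trainer arguments (illustration only, not part of the patch; the assert is just an assumed sanity check):

from dataclasses import dataclass

from transformers import TrainingArguments


@dataclass
class DeviceCPUTrainingArgs(TrainingArguments):
    @property
    def place_model_on_device(self):
        # Tell the Trainer not to move the model to args.device,
        # leaving it wherever it already lives (CPU in the export path).
        return False


# Illustrative usage: a Trainer constructed with these args leaves the model on CPU.
args = DeviceCPUTrainingArgs(output_dir="tmp_trainer")
assert args.place_model_on_device is False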