Keep model on CPU during ONNX export (#1586)
natuan committed May 31, 2023
1 parent 53e2f8d commit 9c7c285
Showing 2 changed files with 15 additions and 4 deletions.
src/sparseml/transformers/export.py (14 additions, 1 deletion)
@@ -75,10 +75,12 @@
import math
import os
import shutil
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from torch.nn import Module
from transformers import AutoConfig, AutoTokenizer
from transformers import TrainingArguments as HFTrainingArgs
from transformers.tokenization_utils_base import PaddingStrategy

from sparseml.optim import parse_recipe_variables
@@ -107,6 +109,14 @@
_LOGGER = logging.getLogger(__name__)


@dataclass
class DeviceCPUTrainingArgs(HFTrainingArgs):
    @property
    def place_model_on_device(self):
        # Ensure model remains in CPU during ONNX export
        return False


def load_task_model(task: str, model_path: str, config: Any) -> Module:
    if task == "masked-language-modeling" or task == "mlm":
        return SparseAutoModel.masked_language_modeling_from_pretrained(
@@ -294,15 +304,18 @@ def export_transformer_to_onnx(
_LOGGER.info(f"loaded validation dataset for args {data_args}")

model = model.train()

args = DeviceCPUTrainingArgs(output_dir="tmp_trainer")
trainer = Trainer(
model=model,
args=args,
model_state_path=model_path,
eval_dataset=eval_dataset,
recipe=None,
recipe_args=None,
teacher=None,
)
model = model.cpu()

applied = trainer.apply_manager(epoch=math.inf, checkpoint=None)

if not applied:
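Why the override works: the Hugging Face Trainer checks TrainingArguments.place_model_on_device while it is being constructed and, when the flag is true, moves the model to args.device (typically a GPU if one is visible). Having the property return False leaves the model wherever it was loaded, so the ONNX export that follows runs against CPU weights, and the explicit model.cpu() call added above reinforces this on the model object used later. Below is a minimal, self-contained sketch of the same pattern using only the stock transformers classes rather than SparseML's Trainer wrapper; the checkpoint and class names are arbitrary choices for illustration.

# A minimal sketch, assuming only stock transformers (no SparseML wrappers):
# subclass TrainingArguments so Trainer never moves the model off the CPU.
from dataclasses import dataclass

from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)


@dataclass
class CPUOnlyTrainingArgs(TrainingArguments):
    @property
    def place_model_on_device(self) -> bool:
        # Trainer.__init__ consults this flag before moving the model to
        # args.device; returning False keeps the model where it was loaded.
        return False


model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
args = CPUOnlyTrainingArgs(output_dir="tmp_trainer")
trainer = Trainer(model=model, args=args)

# Even on a CUDA machine the parameters stay on the CPU, which is what the
# export path in this commit relies on.
assert next(trainer.model.parameters()).device.type == "cpu"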
src/sparseml/transformers/sparsification/trainer.py (1 addition, 3 deletions)
@@ -125,7 +125,7 @@ def __init__(
                training_args_dict=training_args.to_dict(),
                data_args_dict=asdict(data_args) if data_args else {},
            )
            if training_args
            if training_args and metadata_args
            else None
        )

@@ -762,7 +762,6 @@ def _get_fake_dataloader(
        num_samples: int,
        tokenizer: "PreTrainedTokenizerBase",  # noqa: F821
    ):

        # Rearrange inputs' keys to match those defined by model foward func, which
        # seem to define how the order of inputs is determined in the exported model
        forward_args_spec = inspect.getfullargspec(self.model.__class__.forward)
@@ -820,7 +819,6 @@ def __init__(
        teacher: Optional[Union[Module, str]] = None,
        **kwargs,
    ):

        super().__init__(
            model=model,
            model_state_path=model_state_path,
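On the trainer.py side, the changed condition means metadata is assembled only when both training_args and metadata_args are provided, rather than whenever training_args alone is present. A rough, hypothetical sketch of that guard, with helper and argument names invented for illustration:

# Hypothetical, simplified illustration of the guard introduced in this commit.
from dataclasses import asdict
from typing import Any, Dict, List, Optional


def build_metadata(
    training_args: Optional[Any],
    metadata_args: Optional[List[str]],
    data_args: Optional[Any],
) -> Optional[Dict[str, Any]]:
    # Skip metadata entirely unless both argument objects are supplied.
    if not (training_args and metadata_args):
        return None
    return {
        "metadata_args": metadata_args,
        "training_args_dict": training_args.to_dict(),
        "data_args_dict": asdict(data_args) if data_args else {},
    }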
