Keep model in CPU during ONNX export #1586

Merged: 1 commit, May 31, 2023
15 changes: 14 additions & 1 deletion src/sparseml/transformers/export.py
@@ -75,10 +75,12 @@
import math
import os
import shutil
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from torch.nn import Module
from transformers import AutoConfig, AutoTokenizer
from transformers import TrainingArguments as HFTrainingArgs
from transformers.tokenization_utils_base import PaddingStrategy

from sparseml.optim import parse_recipe_variables
@@ -107,6 +109,14 @@
_LOGGER = logging.getLogger(__name__)


@dataclass
class DeviceCPUTrainingArgs(HFTrainingArgs):
@property
def place_model_on_device(self):
# Ensure model remains in CPU during ONNX export
return False


def load_task_model(task: str, model_path: str, config: Any) -> Module:
if task == "masked-language-modeling" or task == "mlm":
return SparseAutoModel.masked_language_modeling_from_pretrained(
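
The override works because the Hugging Face Trainer consults args.place_model_on_device before moving the model to args.device in its constructor; returning False from the property leaves the weights wherever they already are. A minimal, self-contained sketch of the same pattern, using the plain Hugging Face Trainer rather than the sparseml one (the checkpoint name and the final assert are illustrative, not part of this PR):

# Sketch of the override pattern, assuming transformers and torch are installed;
# the model checkpoint below is illustrative only.
from dataclasses import dataclass

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


@dataclass
class CPUOnlyTrainingArgs(TrainingArguments):
    @property
    def place_model_on_device(self) -> bool:
        # Trainer checks this flag before calling model.to(args.device),
        # so returning False keeps the loaded weights on CPU.
        return False


model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
trainer = Trainer(model=model, args=CPUOnlyTrainingArgs(output_dir="tmp_trainer"))
assert next(trainer.model.parameters()).device.type == "cpu"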
@@ -294,15 +304,18 @@ def export_transformer_to_onnx(
_LOGGER.info(f"loaded validation dataset for args {data_args}")

model = model.train()

args = DeviceCPUTrainingArgs(output_dir="tmp_trainer")
trainer = Trainer(
model=model,
args=args,
model_state_path=model_path,
eval_dataset=eval_dataset,
recipe=None,
recipe_args=None,
teacher=None,
)
model = model.cpu()

applied = trainer.apply_manager(epoch=math.inf, checkpoint=None)

if not applied:
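
For context on why the model has to stay on CPU here: torch.onnx.export traces the model with example inputs, and the fake dataloader that feeds the exporter builds plain CPU tensors, so a CUDA-resident model would hit a device mismatch at trace time. A toy illustration of that constraint (a stand-in model, not the sparseml export path):

# Toy illustration (not the sparseml code): torch.onnx.export traces the model
# with the example inputs it is given, so model weights and inputs must live on
# the same device. Exporting from CPU keeps the two aligned regardless of
# whether a GPU is visible.
import torch

model = torch.nn.Linear(4, 2).cpu()
example_input = torch.randn(1, 4)  # created on CPU, matching the model
torch.onnx.export(model, example_input, "linear.onnx")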
4 changes: 1 addition & 3 deletions src/sparseml/transformers/sparsification/trainer.py
@@ -125,7 +125,7 @@ def __init__(
training_args_dict=training_args.to_dict(),
data_args_dict=asdict(data_args) if data_args else {},
)
if training_args
if training_args and metadata_args
else None
)
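
The added metadata_args check complements the export change above: export_transformer_to_onnx now passes real training arguments but no metadata arguments, so metadata extraction should be skipped rather than attempted with a missing argument. A simplified schematic of the guard (illustrative names, not the Trainer's actual code):

# Simplified schematic: metadata is extracted only when both argument objects
# are present; callers that pass training args without metadata args, such as
# the export path, get None instead of a call with a missing argument.
def build_metadata(training_args, metadata_args, data_args, extract):
    return (
        extract(
            metadata_args=metadata_args,
            training_args_dict=training_args.to_dict(),
            data_args_dict=vars(data_args) if data_args else {},
        )
        if training_args and metadata_args
        else None
    )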

@@ -762,7 +762,6 @@ def _get_fake_dataloader(
num_samples: int,
tokenizer: "PreTrainedTokenizerBase", # noqa: F821
):

# Rearrange inputs' keys to match those defined by model forward func, which
# seem to define how the order of inputs is determined in the exported model
forward_args_spec = inspect.getfullargspec(self.model.__class__.forward)
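
The context lines above come from the helper that builds fake inputs for the exporter: the batch keys are reordered to follow the model's forward() signature, since that ordering determines the input order of the exported ONNX graph. A condensed sketch of the idea (the helper name is illustrative, not the trainer method):

# Condensed sketch: reorder a tokenized batch so its keys follow the model's
# forward() signature; names not present in the batch (e.g. "self") are skipped.
import inspect
from collections import OrderedDict


def order_inputs_like_forward(model, batch):
    forward_args = inspect.getfullargspec(model.__class__.forward).args
    return OrderedDict(
        (name, batch[name]) for name in forward_args if name in batch
    )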
@@ -820,7 +819,6 @@ def __init__(
teacher: Optional[Union[Module, str]] = None,
**kwargs,
):

super().__init__(
model=model,
model_state_path=model_state_path,