Enable export paths for LLMs [CodeGen, OPT, Bloom] (#1562)
* initial commit

* missing config arg
dbogunowicz committed May 19, 2023
1 parent ce43c7d commit 934ec0e
Showing 2 changed files with 48 additions and 0 deletions.
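
In short: the change teaches the ONNX export pipeline a new "text-generation" task so that decoder-only LLMs such as CodeGen, OPT, and Bloom can be exported. The full signature of export_transformer_to_onnx is not visible in this diff, so the sketch below is a hedged guess built only from the parameters the hunks reference (task, model_path, sequence_length); the OPT checkpoint is illustrative.

    # Minimal sketch; the exact parameter set and the checkpoint are
    # assumptions, not the verbatim API surface of this commit.
    from sparseml.transformers.export import export_transformer_to_onnx

    export_transformer_to_onnx(
        task="text-generation",
        model_path="facebook/opt-125m",
        sequence_length=128,
    )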
12 changes: 12 additions & 0 deletions src/sparseml/transformers/export.py
@@ -136,6 +136,13 @@ def load_task_model(task: str, model_path: str, config: Any) -> Module:
             model_type="model",
         )
 
+    if task == "text-generation":
+        return SparseAutoModel.text_generation_from_pretrained(
+            model_name_or_path=model_path,
+            config=config,
+            model_type="model",
+        )
+
     raise ValueError(f"unrecognized task given of {task}")

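load_task_model is a plain task-to-loader dispatch, so the new branch simply routes "text-generation" to the SparseAutoModel helper added below instead of falling through to the ValueError. A minimal sketch of a direct call, assuming config is the Hugging Face config object the surrounding code already loads:

    from transformers import AutoConfig

    from sparseml.transformers.export import load_task_model

    # illustrative checkpoint; any CodeGen/OPT/Bloom path works the same way
    config = AutoConfig.from_pretrained("facebook/opt-125m")
    model = load_task_model("text-generation", "facebook/opt-125m", config)
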
@@ -263,6 +270,9 @@ def export_transformer_to_onnx(
     tokenizer = AutoTokenizer.from_pretrained(
         model_path, model_max_length=sequence_length
     )
+    if task == "text-generation":
+        tokenizer.pad_token = tokenizer.eos_token
+
     model = load_task_model(task, model_path, config)
     _LOGGER.info(f"loaded model, config, and tokenizer from {model_path}")

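Decoder-only tokenizers often define no pad token (GPT-2-style vocabularies, which CodeGen inherits), so padding inputs out to sequence_length would otherwise fail; aliasing the pad token to EOS is the usual workaround. A standalone illustration, with gpt2 standing in for the supported checkpoints:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print(tokenizer.pad_token)  # None: gpt2 ships without a pad token
    tokenizer.pad_token = tokenizer.eos_token
    batch = tokenizer(
        "def hello_world():",
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    )
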
@@ -353,12 +363,14 @@ def export_transformer_to_onnx(
     # run export
     model = model.eval()
     onnx_file_path = os.path.join(model_path, onnx_file_name)
+    kwargs = {"input_names": list(inputs.keys())} if task == "text-generation" else {}
 
     export_onnx(
         model,
         inputs,
         onnx_file_path,
         convert_qat=convert_qat,
+        **kwargs,
     )
     _LOGGER.info(f"ONNX exported to {onnx_file_path}")

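The extra kwargs pin the ONNX graph's input names to the tokenizer's output keys (input_ids, attention_mask, ...) instead of letting the exporter auto-generate names. Assuming export_onnx forwards its keyword arguments to torch.onnx.export (its body is not part of this diff), the effect is roughly:

    import torch

    # `inputs` is the tokenizer output used for tracing, as in the code above;
    # a trailing dict in the args tuple is passed to forward() as kwargs
    torch.onnx.export(
        model,
        (dict(inputs),),
        "model.onnx",
        input_names=list(inputs.keys()),
    )
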
36 changes: 36 additions & 0 deletions src/sparseml/transformers/utils/model.py
@@ -20,6 +20,7 @@
 import torch
 from torch.nn import Module
 from transformers import (
+    AutoModelForCausalLM,
     AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
     AutoModelForSequenceClassification,
@@ -235,6 +236,41 @@ def text_classification_from_pretrained_distil(
 
         return model, teacher
 
+    @staticmethod
+    def text_generation_from_pretrained(
+        model_name_or_path: str,
+        model_type: str,
+        **kwargs,
+    ) -> Module:
+        """
+        :param model_name_or_path: the name of or path to the model to load
+        :param model_type: specify the type of model loaded for logging;
+            ex one of [model, student, teacher]
+        :param kwargs: keyword arguments to pass through to the AutoModel call
+        :return: the created model for text generation
+        """
+        SparseAutoModel._check_tf(model_name_or_path)
+        if not kwargs:
+            kwargs = {}
+        kwargs["from_tf"] = False
+        delayed = False
+        if "state_dict" not in kwargs:
+            kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict(
+                model_name_or_path
+            )
+        # Export decoder model without kv cache support
+        kwargs["config"].is_decoder = True
+        kwargs["config"].use_cache = False
+        kwargs["config"].use_past = False
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            **kwargs,
+        )
+        SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed)
+
+        return model
+
     @staticmethod
     def token_classification_from_pretrained(
         model_name_or_path: str,
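
Note that config never appears in the signature: it must arrive through **kwargs (the "missing config arg" fix-up in the commit message) and is mutated in place to disable KV-cache behavior before from_pretrained runs. A hedged usage sketch:

    from transformers import AutoConfig

    from sparseml.transformers.utils.model import SparseAutoModel

    # illustrative checkpoint; omitting config= would raise a KeyError above
    config = AutoConfig.from_pretrained("facebook/opt-125m")
    model = SparseAutoModel.text_generation_from_pretrained(
        model_name_or_path="facebook/opt-125m",
        model_type="model",
        config=config,  # consumed via **kwargs; use_cache/use_past set False
    )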
