Skip to content

Commit

Permalink
Fix export of all quantized transformer models (#1654)
Browse files Browse the repository at this point in the history
* Expose trust_remote_code flag for HF-transformers

* Reload big model with multiple state dict files

* Add description for reload func

* Handle new HF interface

---------

Co-authored-by: Tuan Nguyen <tuan@neuralmagic.com>
  • Loading branch information
eldarkurtic and natuan committed Jul 7, 2023
1 parent 718c7f4 commit 4ec5133
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
22 changes: 20 additions & 2 deletions src/sparseml/transformers/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,23 @@ def place_model_on_device(self):
return False


def load_task_model(task: str, model_path: str, config: Any) -> Module:
def load_task_model(
task: str, model_path: str, config: Any, trust_remote_code: bool = False
) -> Module:
if task == "masked-language-modeling" or task == "mlm":
return SparseAutoModel.masked_language_modeling_from_pretrained(
model_name_or_path=model_path,
config=config,
model_type="model",
trust_remote_code=trust_remote_code,
)

if task == "question-answering" or task == "qa":
return SparseAutoModel.question_answering_from_pretrained(
model_name_or_path=model_path,
config=config,
model_type="model",
trust_remote_code=trust_remote_code,
)

if (
Expand All @@ -142,20 +146,23 @@ def load_task_model(task: str, model_path: str, config: Any) -> Module:
model_name_or_path=model_path,
config=config,
model_type="model",
trust_remote_code=trust_remote_code,
)

if task == "token-classification" or task == "ner":
return SparseAutoModel.token_classification_from_pretrained(
model_name_or_path=model_path,
config=config,
model_type="model",
trust_remote_code=trust_remote_code,
)

if task == "text-generation":
return SparseAutoModel.text_generation_from_pretrained(
model_name_or_path=model_path,
config=config,
model_type="model",
trust_remote_code=trust_remote_code,
)

raise ValueError(f"unrecognized task given of {task}")
Expand Down Expand Up @@ -236,6 +243,7 @@ def export_transformer_to_onnx(
finetuning_task: Optional[str] = None,
onnx_file_name: str = MODEL_ONNX_NAME,
num_export_samples: int = 0,
trust_remote_code: bool = False,
data_args: Optional[Union[Dict[str, Any], str]] = None,
one_shot: Optional[str] = None,
) -> str:
Expand All @@ -255,6 +263,7 @@ def export_transformer_to_onnx(
is model.onnx. Note that when loading a model directory to a deepsparse
pipeline, it will look only for 'model.onnx'
:param num_export_samples: number of samples (inputs/outputs) to export
:param trust_remote_code: set True to allow custom models in HF-transformers
:param data_args: additional args to instantiate a `DataTrainingArguments`
instance for exporting samples
:param one_shot: one shot recipe to be applied before exporting model
Expand All @@ -280,6 +289,7 @@ def export_transformer_to_onnx(
config_args = {"finetuning_task": finetuning_task} if finetuning_task else {}
config = AutoConfig.from_pretrained(
model_path,
trust_remote_code=trust_remote_code,
**config_args,
)
tokenizer = AutoTokenizer.from_pretrained(
Expand All @@ -288,7 +298,7 @@ def export_transformer_to_onnx(
if task == "text-generation":
tokenizer.pad_token = tokenizer.eos_token

model = load_task_model(task, model_path, config)
model = load_task_model(task, model_path, config, trust_remote_code)
_LOGGER.info(f"loaded model, config, and tokenizer from {model_path}")

eval_dataset = None
Expand Down Expand Up @@ -547,6 +557,11 @@ def _parse_args() -> argparse.Namespace:
help="local path or SparseZoo stub to a recipe that should be applied "
"in a one-shot manner before exporting",
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
help=("Set flag to allow custom models in HF-transformers"),
)

return parser.parse_args()

Expand All @@ -559,6 +574,7 @@ def export(
finetuning_task: str,
onnx_file_name: str,
num_export_samples: int = 0,
trust_remote_code: bool = False,
data_args: Optional[str] = None,
one_shot: Optional[str] = None,
):
Expand All @@ -570,6 +586,7 @@ def export(
finetuning_task=finetuning_task,
onnx_file_name=onnx_file_name,
num_export_samples=num_export_samples,
trust_remote_code=trust_remote_code,
data_args=data_args,
one_shot=one_shot,
)
Expand All @@ -593,6 +610,7 @@ def main():
finetuning_task=args.finetuning_task,
onnx_file_name=args.onnx_file_name,
num_export_samples=args.num_export_samples,
trust_remote_code=args.trust_remote_code,
data_args=args.data_args,
one_shot=args.one_shot,
)
Expand Down
10 changes: 8 additions & 2 deletions src/sparseml/transformers/sparsification/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,11 +683,11 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
dd = torch.load(os.path.join(load_path, f), map_location="cpu")
loaded_state_dict.update(dd)

_, missing, unexpected, _, _ = self.model._load_pretrained_model(
_, missing, unexpected, mismatched, _, _ = self.model._load_pretrained_model(
model=self.model,
state_dict=loaded_state_dict,
loaded_keys=list(loaded_state_dict.keys()),
resolved_archive_file=[],
resolved_archive_file=None,
pretrained_model_name_or_path=load_path,
_fast_init=False,
)
Expand All @@ -704,6 +704,12 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
f"{unexpected}"
)

if mismatched:
_LOGGER.warning(
f"Mismatched keys found when reloading model state for SparseML recipe:"
f"{mismatched}"
)

total_loaded = len(current_state_dict) - (len(missing) if len(missing) else 0)
_LOGGER.info(
f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}"
Expand Down

0 comments on commit 4ec5133

Please sign in to comment.