Fix export of all quantized transformer models #1654

Merged: 9 commits, Jul 7, 2023
src/sparseml/transformers/export.py (22 changes: 20 additions & 2 deletions)
@@ -117,19 +117,23 @@ def place_model_on_device(self):
         return False
 
 
-def load_task_model(task: str, model_path: str, config: Any) -> Module:
+def load_task_model(
+    task: str, model_path: str, config: Any, trust_remote_code: bool = False
+) -> Module:
     if task == "masked-language-modeling" or task == "mlm":
         return SparseAutoModel.masked_language_modeling_from_pretrained(
             model_name_or_path=model_path,
             config=config,
             model_type="model",
+            trust_remote_code=trust_remote_code,
         )
 
     if task == "question-answering" or task == "qa":
         return SparseAutoModel.question_answering_from_pretrained(
             model_name_or_path=model_path,
             config=config,
             model_type="model",
+            trust_remote_code=trust_remote_code,
         )
 
     if (
@@ -142,20 +146,23 @@ def load_task_model(task: str, model_path: str, config: Any) -> Module:
             model_name_or_path=model_path,
             config=config,
             model_type="model",
+            trust_remote_code=trust_remote_code,
         )
 
     if task == "token-classification" or task == "ner":
         return SparseAutoModel.token_classification_from_pretrained(
             model_name_or_path=model_path,
             config=config,
             model_type="model",
+            trust_remote_code=trust_remote_code,
         )
 
     if task == "text-generation":
         return SparseAutoModel.text_generation_from_pretrained(
             model_name_or_path=model_path,
             config=config,
             model_type="model",
+            trust_remote_code=trust_remote_code,
         )
 
     raise ValueError(f"unrecognized task given of {task}")
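A minimal usage sketch (not part of the diff) of how the extended loader would be called. The checkpoint path is hypothetical, and `trust_remote_code` has to reach the config load as well, which the `AutoConfig.from_pretrained` hunk further down handles:

```python
from transformers import AutoConfig

from sparseml.transformers.export import load_task_model

model_path = "./my-custom-model"  # hypothetical checkpoint that ships remote code

# The config must also be loaded with trust_remote_code=True so the custom
# architecture registered by the checkpoint can be resolved.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

model = load_task_model(
    task="text-generation",
    model_path=model_path,
    config=config,
    trust_remote_code=True,  # forwarded to SparseAutoModel.*_from_pretrained
)
```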
@@ -236,6 +243,7 @@ def export_transformer_to_onnx(
     finetuning_task: Optional[str] = None,
     onnx_file_name: str = MODEL_ONNX_NAME,
     num_export_samples: int = 0,
+    trust_remote_code: bool = False,
     data_args: Optional[Union[Dict[str, Any], str]] = None,
     one_shot: Optional[str] = None,
 ) -> str:
@@ -255,6 +263,7 @@ def export_transformer_to_onnx(
         is model.onnx. Note that when loading a model directory to a deepsparse
         pipeline, it will look only for 'model.onnx'
     :param num_export_samples: number of samples (inputs/outputs) to export
+    :param trust_remote_code: set True to allow custom models in HF-transformers
     :param data_args: additional args to instantiate a `DataTrainingArguments`
         instance for exporting samples
     :param one_shot: one shot recipe to be applied before exporting model
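For the Python API, an export call then looks roughly like this (a sketch: the leading `task`/`model_path` parameters are hidden in the collapsed context above, and the paths are hypothetical):

```python
from sparseml.transformers.export import export_transformer_to_onnx

onnx_path = export_transformer_to_onnx(
    task="question-answering",
    model_path="./my-quantized-model",  # hypothetical local checkpoint
    num_export_samples=10,
    trust_remote_code=True,  # new flag: allow custom modeling code
)
print(f"ONNX model written to {onnx_path}")
```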
@@ -280,6 +289,7 @@ def export_transformer_to_onnx(
     config_args = {"finetuning_task": finetuning_task} if finetuning_task else {}
     config = AutoConfig.from_pretrained(
         model_path,
+        trust_remote_code=trust_remote_code,
         **config_args,
     )
     tokenizer = AutoTokenizer.from_pretrained(
@@ -288,7 +298,7 @@
     if task == "text-generation":
         tokenizer.pad_token = tokenizer.eos_token
 
-    model = load_task_model(task, model_path, config)
+    model = load_task_model(task, model_path, config, trust_remote_code)
     _LOGGER.info(f"loaded model, config, and tokenizer from {model_path}")
 
     eval_dataset = None
@@ -547,6 +557,11 @@ def _parse_args() -> argparse.Namespace:
         help="local path or SparseZoo stub to a recipe that should be applied "
         "in a one-shot manner before exporting",
     )
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        help=("Set flag to allow custom models in HF-transformers"),
+    )
 
     return parser.parse_args()
 
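From the command line the flag is a plain store_true switch, so exporting a custom-code model would look something like the following (the script invocation and the `--task`/`--model_path` flag names are assumptions based on the rest of this file, which is collapsed here):

```bash
python src/sparseml/transformers/export.py \
    --task text-generation \
    --model_path ./my-custom-model \
    --trust_remote_code
```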
@@ -559,6 +574,7 @@ def export(
     finetuning_task: str,
     onnx_file_name: str,
     num_export_samples: int = 0,
+    trust_remote_code: bool = False,
     data_args: Optional[str] = None,
     one_shot: Optional[str] = None,
 ):
@@ -570,6 +586,7 @@ def export(
         finetuning_task=finetuning_task,
         onnx_file_name=onnx_file_name,
         num_export_samples=num_export_samples,
+        trust_remote_code=trust_remote_code,
         data_args=data_args,
         one_shot=one_shot,
     )
@@ -593,6 +610,7 @@ def main():
         finetuning_task=args.finetuning_task,
         onnx_file_name=args.onnx_file_name,
         num_export_samples=args.num_export_samples,
+        trust_remote_code=args.trust_remote_code,
         data_args=args.data_args,
         one_shot=args.one_shot,
     )
src/sparseml/transformers/sparsification/trainer.py (10 changes: 8 additions & 2 deletions)
@@ -683,11 +683,11 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
             dd = torch.load(os.path.join(load_path, f), map_location="cpu")
             loaded_state_dict.update(dd)
 
-        _, missing, unexpected, _, _ = self.model._load_pretrained_model(
+        _, missing, unexpected, mismatched, _, _ = self.model._load_pretrained_model(
             model=self.model,
             state_dict=loaded_state_dict,
             loaded_keys=list(loaded_state_dict.keys()),
-            resolved_archive_file=[],
+            resolved_archive_file=None,
             pretrained_model_name_or_path=load_path,
             _fast_init=False,
         )
@@ -704,6 +704,12 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
                 f"{unexpected}"
             )
 
+        if mismatched:
+            _LOGGER.warning(
+                f"Mismatched keys found when reloading model state for SparseML recipe: "
+                f"{mismatched}"
+            )
+
         total_loaded = len(current_state_dict) - (len(missing) if len(missing) else 0)
         _LOGGER.info(
             f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}"
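Context for the new `mismatched` bucket: the extra element unpacked from `_load_pretrained_model` corresponds to keys present in both state dicts but with different tensor shapes, a case plain PyTorch refuses to load outright. A self-contained illustration of the failure mode this warning now surfaces (not SparseML code):

```python
import torch

# Key "weight" exists in both the module and the checkpoint, but the
# shapes differ: this is what gets reported as a mismatched key.
model = torch.nn.Linear(4, 2)  # weight shape (2, 4)
bad_ckpt = {"weight": torch.zeros(3, 4), "bias": torch.zeros(2)}

try:
    # strict=False only forgives missing/unexpected keys; a shape
    # mismatch still raises in vanilla PyTorch.
    model.load_state_dict(bad_ckpt, strict=False)
except RuntimeError as err:
    print(f"shape mismatch: {err}")
```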