Skip to content

Commit

Permalink
Fix onnx export for quantized MPT models 2 (#1647)
Browse files Browse the repository at this point in the history
MPT models appear as HF models wrapped additionally for Composer library compatibility which results in `model.` prefix (instead of the classic `module.`) At the moment this causes errors during export of quantized models as some modules are missed during matching of names. This PR fixes it.
  • Loading branch information
eldarkurtic committed Jul 14, 2023
1 parent 4f0a19e commit e8ff0f6
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/sparseml/pytorch/sparsification/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ def _match_submodule_name_or_type(
submodule_match = ""
for name_or_type in names_or_types:
name_to_compare = submodule_name[:]
if name_to_compare.startswith("module."):
name_to_compare = name_to_compare[7:]
name_to_compare = _maybe_discard_prefix(name_to_compare)
if name_or_type == submodule.__class__.__name__:
# type match, return type name
return name_or_type
Expand Down Expand Up @@ -426,8 +425,7 @@ def _get_unmatched_types_or_names(types_or_names):
matched = False
for submodule_name, submodule in model.named_modules():
name_to_compare = submodule_name[:]
if name_to_compare.startswith("module."):
name_to_compare = name_to_compare[7:]
name_to_compare = _maybe_discard_prefix(name_to_compare)
if name_to_compare.startswith(type_or_name) or (
submodule.__class__.__name__ == type_or_name
):
Expand All @@ -453,3 +451,10 @@ def _build_error_str(property_name, unmatched_values):
unmatched_ignore = _get_unmatched_types_or_names(ignore)
if unmatched_ignore:
raise ValueError(_build_error_str("ignore", unmatched_ignore))


def _maybe_discard_prefix(name: str) -> str:
for prefix in ["module.", "model."]:
if name.startswith(prefix):
name = name.replace(prefix, "", 1)
return name

0 comments on commit e8ff0f6

Please sign in to comment.