Skip to content

Commit

Permalink
Revert "Fix onnx export for quantized MPT models 2 (#1647)" (#1683)
Browse files Browse the repository at this point in the history
This reverts commit e8ff0f6.
  • Loading branch information
rahul-tuli committed Jul 21, 2023
1 parent 925e135 commit 09c9c96
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions src/sparseml/pytorch/sparsification/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ def _match_submodule_name_or_type(
submodule_match = ""
for name_or_type in names_or_types:
name_to_compare = submodule_name[:]
name_to_compare = _maybe_discard_prefix(name_to_compare)
if name_to_compare.startswith("module."):
name_to_compare = name_to_compare[7:]
if name_or_type == submodule.__class__.__name__:
# type match, return type name
return name_or_type
Expand Down Expand Up @@ -425,7 +426,8 @@ def _get_unmatched_types_or_names(types_or_names):
matched = False
for submodule_name, submodule in model.named_modules():
name_to_compare = submodule_name[:]
name_to_compare = _maybe_discard_prefix(name_to_compare)
if name_to_compare.startswith("module."):
name_to_compare = name_to_compare[7:]
if name_to_compare.startswith(type_or_name) or (
submodule.__class__.__name__ == type_or_name
):
Expand All @@ -451,10 +453,3 @@ def _build_error_str(property_name, unmatched_values):
unmatched_ignore = _get_unmatched_types_or_names(ignore)
if unmatched_ignore:
raise ValueError(_build_error_str("ignore", unmatched_ignore))


def _maybe_discard_prefix(name: str) -> str:
for prefix in ["module.", "model."]:
if name.startswith(prefix):
name = name.replace(prefix, "", 1)
return name

0 comments on commit 09c9c96

Please sign in to comment.