Skip to content

Commit

Permalink
raise exception for GMPruningModifier args that are no longer support…
Browse files Browse the repository at this point in the history
…ed (#640) (#641)
  • Loading branch information
bfineran committed Mar 23, 2022
1 parent 77166e4 commit 8096e15
Showing 1 changed file with 25 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ class GMPruningModifier(BaseGradualPruningModifier, BaseGMPruningModifier):
:param global_sparsity: set True to use global magnitude pruning, False for
layer-wise. Default is False. [DEPRECATED] - use GlobalMagnitudePruningModifier
for global magnitude pruning and MagnitudePruningModifier for layer-wise
:param phased: NO LONGER SUPPORTED - former parameter for AC/DC pruning. Will raise
an exception if set to True. Use ACDCPruningModifier for AC/DC pruning
:param score_type: NO LONGER SUPPORTED - former parameter for using different
sparsification algorithms, will raise an exception if set to the non default
value
"""

def __init__(
Expand All @@ -128,8 +133,10 @@ def __init__(
log_types: Union[str, List[str]] = ALL_TOKEN,
mask_type: str = "unstructured",
global_sparsity: bool = False,
phased: bool = False,
score_type: str = "magnitude",
):
self._check_warn_global_sparsity(global_sparsity)
self._check_deprecated_params(global_sparsity, phased, score_type)

super(GMPruningModifier, self).__init__(
params=params,
Expand Down Expand Up @@ -180,13 +187,29 @@ def global_sparsity(self) -> bool:
"""
return self._global_sparsity

def _check_warn_global_sparsity(self, global_sparsity):
def _check_deprecated_params(
self,
global_sparsity: bool,
phased: bool,
score_type: str,
):
if self.__class__.__name__ == "GMPruningModifier" and global_sparsity is True:
_LOGGER.warning(
"Use of global_sparsity parameter in GMPruningModifier is now "
"deprecated. Use GlobalMagnitudePruningModifier instead for global "
"magnitude pruning"
)
if phased:
raise ValueError(
f"Use of phased=True in {self.__class__.__name__} is no longer "
"supported use the ACDCPruningModifier for phased (AC/DC) pruning"
)
if score_type != "magnitude":
raise ValueError(
"use of score_type to specify a sparsification algorithm is no longer "
"supported. Use the specific pruning modifier for the desired "
f"sparsification algorithm instead. Found score_type={score_type}"
)


@PyTorchModifierYAML()
Expand Down

0 comments on commit 8096e15

Please sign in to comment.