Skip to content

Commit

Permalink
PyTorch modifier fixes (#239) (#240)
Browse files Browse the repository at this point in the history
- allow block_shape to be more than two dimensions since on export it changes to weight shape
- move initialized check under try catch in delete for safety
  • Loading branch information
markurtz committed May 13, 2021
1 parent 62b1938 commit 97bf486
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
16 changes: 13 additions & 3 deletions src/sparseml/pytorch/optim/mask_creator_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,13 +474,23 @@ def __init__(
block_shape: List[int],
grouping_fn_name: str = "mean",
):
if len(block_shape) != 2:
if len(block_shape) < 2:
raise ValueError(
(
"Invalid block_shape: {}"
" ,block_shape must have length == 2 for in and out channels"
"Invalid block_shape: {}, "
"block_shape must have length == 2 for in and out channels"
).format(block_shape)
)

if len(block_shape) > 2 and not all([shape == 1 for shape in block_shape[2:]]):
# after in and out channels, only 1 can be used for other dimensions
raise ValueError(
(
"Invalid block_shape: {}, "
"block_shape for indices not in [0, 1] must be equal to 1"
).format(block_shape)
)

self._block_shape = block_shape
self._grouping_fn_name = grouping_fn_name

Expand Down
5 changes: 2 additions & 3 deletions src/sparseml/pytorch/optim/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,9 @@ def __init__(self, log_types: Union[str, List[str]] = None, **kwargs):
self._loggers = None

def __del__(self):
if not self.initialized:
return

try:
if not self.initialized:
return
self.finalize()
except Exception:
pass
Expand Down

0 comments on commit 97bf486

Please sign in to comment.