Skip to content

Commit

Permalink
raise exception when not all grads have been supplied for WF pruning (#…
Browse files: browse the repository at this point in the history
  • Loading branch information
bfineran committed Jun 21, 2021
1 parent 7676b27 commit 8502988
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/sparseml/pytorch/optim/mask_pruning_scorer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -332,6 +332,16 @@ def score_parameters(self) -> List[Tensor]:
given by the OBS method. For the approximated Hessian inverse matrix
H^-1, scores will be W^2 / (2 * diag(H^-1))
"""

if self._grad_buffer is None or torch.any(
torch.all(self._grad_buffer == 0.0, dim=1)
):
# raise Exception if grad buffer is not full
raise RuntimeError(
"MFAC pruning step called, but not enough gradient samples have been "
f"collected. Expected {self._mfac_options.num_grads} samples"
)

if self._is_ddp:
# move all grads to one device
if self._is_main_proc:
Expand Down Expand Up @@ -450,10 +460,6 @@ def get_name() -> str:

def _score_parameters(self) -> List[Tensor]:
# score params using MFAC and the gathered grad buffers
if torch.any(torch.all(self._grads == 0.0, dim=1)):
# if not all grads are captured, return magnitudes as scores
return [torch.abs(param.data) for param in self._params]

# gather non-pruned weights
non_pruned_weights = torch.empty(self._grads.size(1)).to(self._grads.device)
weights_idx = 0
Expand Down

0 comments on commit 8502988

Please sign in to comment.