DDP: Collect gradients in all processes (#1293)
eldarkurtic authored Jan 6, 2023
1 parent b76f25a commit ef277e5
Showing 1 changed file with 38 additions and 40 deletions.
78 changes: 38 additions & 40 deletions src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py
@@ -232,7 +232,6 @@ def initialize(
         :param kwargs: optional kwargs to support specific arguments
             for individual modifiers.
         """
-        _LOGGER.info("Initializing OBSPruningModifier")
         if (
             "grad_sampler" not in kwargs
             or "data_loader_builder" not in kwargs["grad_sampler"]
@@ -243,60 +242,59 @@
                 "must be provided to initialize GradSampler"
             )
 
+        super().initialize(module, epoch, loggers, **kwargs)
+        self._grad_sampler = GradSampler(
+            kwargs["grad_sampler"]["data_loader_builder"](self._grad_sampler_kwargs),
+            kwargs["grad_sampler"]["loss_function"],
+        )
 
-        if self._scorer._is_main_proc:  # grads collected only in the main proc
-            self._grad_sampler = GradSampler(
-                kwargs["grad_sampler"]["data_loader_builder"](
-                    self._grad_sampler_kwargs
-                ),
-                kwargs["grad_sampler"]["loss_function"],
-            )
-        super().initialize(module, epoch, loggers, **kwargs)
 
     def check_mask_update(
         self, module: Module, epoch: float, steps_per_epoch: int, **kwargs
     ):
         if steps_per_epoch == 1 and not math.isinf(epoch):
             return  # not a one-shot run
 
+        _LOGGER.info("Running OBS Pruning")
         torch.cuda.empty_cache()
         if self._scorer._is_main_proc:
-            self._pre_step_completed = True
-            _LOGGER.info("Running OBS Pruning")
             self._scorer._enabled_grad_buffering = True
-            to_apply_sparsities = self.get_applied_sparsity_for_epoch(
-                epoch, steps_per_epoch
-            )
-            last_applied_sparsities = (
-                self._last_applied_sparsity
-                if isinstance(self._last_applied_sparsity, List)
-                else [self._last_applied_sparsity] * len(to_apply_sparsities)
-            )
-
-            for i in range(1, self._num_recomputations + 1):
-                self._collect_grad_samples(module, self._grad_sampler)
-                recomputation_sparsity = [
-                    interpolate(
-                        i,
-                        0,
-                        self._num_recomputations,
-                        start_sparsity,
-                        target_sparsity,
-                    )
-                    for start_sparsity, target_sparsity in zip(
-                        last_applied_sparsities, to_apply_sparsities
-                    )
-                ]
-                super().check_mask_update(
-                    module,
-                    epoch,
-                    steps_per_epoch,
-                    recomputation_sparsity=recomputation_sparsity,
-                )
+        self._pre_step_completed = True
+        to_apply_sparsities = self.get_applied_sparsity_for_epoch(
+            epoch, steps_per_epoch
+        )
+        last_applied_sparsities = (
+            self._last_applied_sparsity
+            if isinstance(self._last_applied_sparsity, List)
+            else [self._last_applied_sparsity] * len(to_apply_sparsities)
+        )
+
+        for i in range(1, self._num_recomputations + 1):
+            self._collect_grad_samples(module, self._grad_sampler)
+            recomputation_sparsity = [
+                interpolate(
+                    i,
+                    0,
+                    self._num_recomputations,
+                    start_sparsity,
+                    target_sparsity,
+                )
+                for start_sparsity, target_sparsity in zip(
+                    last_applied_sparsities, to_apply_sparsities
+                )
+            ]
+            super().check_mask_update(
+                module,
+                epoch,
+                steps_per_epoch,
+                recomputation_sparsity=recomputation_sparsity,
+            )
 
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
+        self._last_applied_sparsity = to_apply_sparsities
+        if self._scorer._is_main_proc:
             self._scorer._enabled_grad_buffering = False
-            self._last_applied_sparsity = to_apply_sparsities
 
     def _get_mask_creator(
         self, param_names: List[str], params: List[Parameter]
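For context, a minimal sketch of the pattern this diff moves toward: gradient samples are collected on every distributed rank instead of being guarded by an is-main-process check. This is not the SparseML GradSampler API; the model, loss, and loader below are placeholders, and only the structure (every rank runs the same collection loop) is meant to mirror the change.

# Hedged sketch, not SparseML code: per-rank gradient-sample collection,
# mirroring the removal of the main-process-only guard in this commit.
import torch
import torch.distributed as dist
from torch.nn import Linear, MSELoss
from torch.utils.data import DataLoader, TensorDataset


def collect_grad_samples(model, loss_fn, data_loader, num_batches):
    """Accumulate parameter gradients over a few batches on the local rank."""
    grads = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        if batch_idx >= num_batches:
            break
        model.zero_grad(set_to_none=False)
        loss_fn(model(inputs), targets).backward()
        for name, p in model.named_parameters():
            grads[name] += p.grad.detach()
    return grads


if __name__ == "__main__":
    # Single-process fallback; under torchrun every rank would execute this
    # same collection, which is the point of the change.
    rank = dist.get_rank() if dist.is_initialized() else 0
    model = Linear(4, 1)
    data = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    loader = DataLoader(data, batch_size=8)
    grads = collect_grad_samples(model, MSELoss(), loader, num_batches=2)
    print(f"rank {rank}: collected grads for {list(grads)}")

Running the collection identically on all ranks presumably gives each process the gradient statistics its local mask update needs; in the diff above, only the scorer's grad-buffering flag remains guarded by the main-process check.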
