diff --git a/src/sparseml/modifiers/pruning/constant/pytorch.py b/src/sparseml/modifiers/pruning/constant/pytorch.py
index c18cbd9dc36..a45fd603e27 100644
--- a/src/sparseml/modifiers/pruning/constant/pytorch.py
+++ b/src/sparseml/modifiers/pruning/constant/pytorch.py
@@ -71,6 +71,9 @@ def on_update(self, state: State, event: Event, **kwargs):
         def apply_masks(module):
             mask_name = param_mask_name()
             if hasattr(module, mask_name):
+                mask = getattr(module, mask_name)
+                if mask.device != module.weight.device:
+                    setattr(module, mask_name, mask.to(module.weight.device))
                 module.weight *= getattr(module, mask_name)
 
         state.model.model.apply(apply_masks)
diff --git a/src/sparseml/transformers/finetune/runner.py b/src/sparseml/transformers/finetune/runner.py
index f57eec3c945..cbcf9cea71f 100644
--- a/src/sparseml/transformers/finetune/runner.py
+++ b/src/sparseml/transformers/finetune/runner.py
@@ -40,9 +40,7 @@
 )
 from sparseml.transformers.finetune.model_args import ModelArguments
 from sparseml.transformers.finetune.training_args import TrainingArguments
-from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, unwrap_and_export_model
-from sparseml.utils.pytorch import qat_active
 
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -287,12 +285,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
             session = session_manager.active_session()
             session.reset_stage()
 
-            # log model sparsity
-            with summon_full_params_context(self.trainer.model):
-                if self.trainer.accelerator.is_main_process:
-                    if not qat_active(self.trainer.model):
-                        self.trainer.log_model_sparsification()
-
             # synchronize and clean up memory
             self.trainer.accelerator.wait_for_everyone()
             self.trainer.model = get_session_model()
diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py
index 8df8515de34..97a1debec42 100644
--- a/src/sparseml/transformers/finetune/session_mixin.py
+++ b/src/sparseml/transformers/finetune/session_mixin.py
@@ -40,6 +40,7 @@
 )
 from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
+from sparseml.utils.pytorch import qat_active
 
 
 __all__ = [
@@ -137,7 +138,7 @@ def initialize_session(
         train_data = self.get_train_dataloader()
 
         self.accelerator.wait_for_everyone()
-        with summon_full_params_context(self.model):
+        with summon_full_params_context(self.model, offload_to_cpu=True):
             session_manager.initialize(
                 model=self.model,
                 teacher_model=self.teacher,  # TODO: what about for self/disable?
@@ -370,9 +371,13 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
 
         self.accelerator.wait_for_everyone()
 
-        # Need to gather parameters across the GPUs before accessing layer weights
-        with summon_full_params_context(self.model):
-            self.log_model_sparsification()
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
+        self.accelerator.wait_for_everyone()
 
         return output
 
@@ -434,6 +439,12 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             accelerator=self.accelerator,
         )
 
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
         self.accelerator.wait_for_everyone()
 
     def save_model(