Fix for OOM Errors during Ultrachat200k Finetuning (#2180) (#2181)
* testing fix

* get rid of repeated log

* revert yaml
Satrat committed Mar 14, 2024
1 parent d64d9fb commit 3bf79ad
Showing 3 changed files with 18 additions and 12 deletions.
3 changes: 3 additions & 0 deletions src/sparseml/modifiers/pruning/constant/pytorch.py
@@ -71,6 +71,9 @@ def on_update(self, state: State, event: Event, **kwargs):
         def apply_masks(module):
             mask_name = param_mask_name()
             if hasattr(module, mask_name):
+                mask = getattr(module, mask_name)
+                if mask.device != module.weight.device:
+                    setattr(module, mask_name, mask.to(module.weight.device))
                 module.weight *= getattr(module, mask_name)

         state.model.model.apply(apply_masks)
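The hunk above guards against the pruning-mask buffer sitting on a different device than the weight it masks, which can happen once the finetuned model has been sharded or offloaded by FSDP. A minimal, self-contained sketch of the same pattern follows; the buffer name and toy layer are illustrative, not sparseml's internals:

import torch
import torch.nn as nn

MASK_NAME = "weight_mask"  # illustrative; sparseml derives the real name via param_mask_name()


def apply_masks(module: nn.Module) -> None:
    """Zero out pruned weights, aligning the mask's device with the weight's first."""
    if hasattr(module, MASK_NAME):
        mask = getattr(module, MASK_NAME)
        # After sharding or offload the mask buffer can live on a different
        # device than the gathered weight; move it over, otherwise the
        # in-place multiply fails with a device-mismatch error.
        if mask.device != module.weight.device:
            setattr(module, MASK_NAME, mask.to(module.weight.device))
        with torch.no_grad():
            module.weight *= getattr(module, MASK_NAME)


layer = nn.Linear(4, 4)
layer.register_buffer(MASK_NAME, (torch.rand(4, 4) > 0.5).float())
layer.apply(apply_masks)  # nn.Module.apply visits every submodule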
8 changes: 0 additions & 8 deletions src/sparseml/transformers/finetune/runner.py
@@ -40,9 +40,7 @@
 )
 from sparseml.transformers.finetune.model_args import ModelArguments
 from sparseml.transformers.finetune.training_args import TrainingArguments
-from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, unwrap_and_export_model
-from sparseml.utils.pytorch import qat_active


 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -287,12 +285,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
             session = session_manager.active_session()
             session.reset_stage()

-            # log model sparsity
-            with summon_full_params_context(self.trainer.model):
-                if self.trainer.accelerator.is_main_process:
-                    if not qat_active(self.trainer.model):
-                        self.trainer.log_model_sparsification()
-
             # synchronize and clean up memory
             self.trainer.accelerator.wait_for_everyone()
             self.trainer.model = get_session_model()
19 changes: 15 additions & 4 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -40,6 +40,7 @@
 )
 from sparseml.utils.fsdp.context import summon_full_params_context
 from sparseml.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
+from sparseml.utils.pytorch import qat_active


 __all__ = [
@@ -137,7 +138,7 @@ def initialize_session(
         train_data = self.get_train_dataloader()

         self.accelerator.wait_for_everyone()
-        with summon_full_params_context(self.model):
+        with summon_full_params_context(self.model, offload_to_cpu=True):
             session_manager.initialize(
                 model=self.model,
                 teacher_model=self.teacher,  # TODO: what about for self/disable?
@@ -370,9 +371,13 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):

         self.accelerator.wait_for_everyone()

-        # Need to gather parameters across the GPUs before accessing layer weights
-        with summon_full_params_context(self.model):
-            self.log_model_sparsification()
+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
+        self.accelerator.wait_for_everyone()

         return output

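This hunk and the one_shot hunk below apply the same recipe for the OOM fix: gather the sharded parameters onto CPU (offload_to_cpu=True) rather than into GPU memory, inspect sparsity only on the main process, skip the log when QAT is active, then resynchronize the ranks. What follows is a rough stand-alone sketch of that recipe, assuming sparseml's summon_full_params_context is a thin wrapper over PyTorch's FSDP.summon_full_params; gather_and_log and log_model_sparsity are illustrative names, not sparseml's API:

import contextlib

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from sparseml.utils.pytorch import qat_active  # same helper the diff imports


def log_model_sparsity(model: torch.nn.Module) -> None:
    """Illustrative stand-in for log_model_sparsification(): print the zero fraction."""
    total = zeros = 0
    for param in model.parameters():
        total += param.numel()
        zeros += int((param == 0).sum())
    print(f"global sparsity: {zeros / max(total, 1):.2%}")


def gather_and_log(model: torch.nn.Module, is_main_process: bool = True) -> None:
    # Gather the full parameters onto *CPU* so the temporary unsharded copy
    # does not exhaust GPU memory; writeback=False because this is read-only.
    if isinstance(model, FSDP):
        ctx = FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True)
    else:
        ctx = contextlib.nullcontext()
    with ctx:
        # Only the main process inspects the weights, and only when
        # quantization-aware training is not active, matching the guard above.
        if is_main_process and not qat_active(model):
            log_model_sparsity(model)
    if dist.is_initialized():
        dist.barrier()  # mirrors accelerator.wait_for_everyone()


# Single-process usage; the non-FSDP branch needs no process group:
gather_and_log(torch.nn.Linear(8, 8))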
@@ -434,6 +439,12 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             accelerator=self.accelerator,
         )

+        # log model sparsity
+        with summon_full_params_context(self.model, offload_to_cpu=True):
+            if self.accelerator.is_main_process:
+                if not qat_active(self.model):
+                    self.log_model_sparsification()
+
         self.accelerator.wait_for_everyone()

     def save_model(
