fix errors in trainer save (#1213)
Signed-off-by: Dillon Laird <dillonalaird@gmail.com>
dillonalaird committed Jan 31, 2024
1 parent 9bc38ae commit ff501d0
Showing 2 changed files with 7 additions and 3 deletions.
@@ -428,10 +428,10 @@ def _save_checkpoint(self, model, trial, metrics=None):
             self.model.config.save_pretrained(output_dir)
             torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
         else:
-            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
+            super(GaudiLLaVATrainer, self)._save_checkpoint(model, trial, metrics)
 
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
         if getattr(self.args, 'tune_mm_mlp_adapter', False):
             pass
         else:
-            super(LLaVATrainer, self)._save(output_dir, state_dict)
+            super(GaudiLLaVATrainer, self)._save(output_dir, state_dict)
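
The first hunk replaces the hard-coded LLaVATrainer reference in the two-argument super() calls with GaudiLLaVATrainer. A minimal sketch of why the class passed to super() must match the defining class, assuming GaudiLLaVATrainer does not inherit from LLaVATrainer (the base class below is a stand-in, not the real Trainer):

# Minimal sketch, assuming GaudiLLaVATrainer replaces rather than subclasses LLaVATrainer.
class Trainer:
    def _save(self, output_dir=None, state_dict=None):
        print(f"base Trainer._save -> {output_dir}")

class GaudiLLaVATrainer(Trainer):
    def _save(self, output_dir=None, state_dict=None):
        # super(LLaVATrainer, self) would fail here: NameError if LLaVATrainer
        # is not imported, or TypeError because self is not an instance of it.
        super(GaudiLLaVATrainer, self)._save(output_dir, state_dict)

GaudiLLaVATrainer()._save("checkpoints/")  # prints: base Trainer._save -> checkpoints/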

@@ -149,7 +149,11 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
         return
 
     if trainer.deepspeed:
-        torch.cuda.synchronize()
+        if is_hpu_available:
+            import habana_frameworks.torch as ht
+            ht.hpu.synchronize()
+        else:
+            torch.cuda.synchronize()
         trainer.save_model(output_dir)
         return
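
The second hunk makes the pre-save synchronization device-aware: on Gaudi machines the Habana (HPU) bridge is synchronized instead of CUDA. A minimal standalone sketch of the same pattern; the helper name and the HPU-detection logic are assumptions for illustration, while the commit itself relies on an is_hpu_available flag defined elsewhere in the file:

# Minimal sketch of a device-aware synchronize, mirroring the branch above.
import importlib.util
import torch

def synchronize_accelerator() -> None:
    if importlib.util.find_spec("habana_frameworks") is not None:
        import habana_frameworks.torch as ht  # Habana PyTorch bridge (Gaudi only)
        ht.hpu.synchronize()                  # wait for pending HPU work before saving
    elif torch.cuda.is_available():
        torch.cuda.synchronize()              # wait for pending CUDA work before saving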
