Skip to content

Commit

Permalink
Update training_tricks.py (#3151)
Browse files Browse the repository at this point in the history
* Update training_tricks.py

* pep

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 26, 2020
1 parent cb0c60b commit 0112355
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def scale_batch_size(self,
def __scale_batch_dump_params(self):
# Prevent going into infinite loop
self.__dumped_params = {
'auto_lr_find': self.auto_lr_find,
'max_steps': self.max_steps,
'weights_summary': self.weights_summary,
'logger': self.logger,
Expand All @@ -226,6 +227,7 @@ def __scale_batch_dump_params(self):

def __scale_batch_reset_params(self, model, steps_per_trial):
self.auto_scale_batch_size = None # prevent recursion
self.auto_lr_find = False # avoid lr find being called multiple times
self.max_steps = steps_per_trial # take few steps
self.weights_summary = None # not needed before full run
self.logger = DummyLogger()
Expand All @@ -237,6 +239,7 @@ def __scale_batch_reset_params(self, model, steps_per_trial):
self.model = model # required for saving

def __scale_batch_restore_params(self):
self.auto_lr_find = self.__dumped_params['auto_lr_find']
self.max_steps = self.__dumped_params['max_steps']
self.weights_summary = self.__dumped_params['weights_summary']
self.logger = self.__dumped_params['logger']
Expand Down

0 comments on commit 0112355

Please sign in to comment.