Fix logger bug and prepare data bug (#1933)
* tests, fix logger bug and prepare data bug

* add CHANGELOG.md

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
SkafteNicki and Nicki Skafte committed May 25, 2020
1 parent 033ddc0 commit a34eb9e
Showing 4 changed files with 54 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
 
+- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
+
 ## [0.7.6] - 2020-05-16
 
 ### Added
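
The changelog entry above describes the user-visible half of the commit: after fit() runs one of the tuner algorithms, model.logger should still point at the trainer's logger. Below is a minimal sketch of that expectation, written against the 0.7.x-era API used in the tests added in this commit; the EvalModelTemplate import path and the temporary save path are assumptions, and the snippet is an illustration rather than part of the commit.

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate  # assumed location of the test-suite template model

# same toy model the new tests construct
model = EvalModelTemplate(EvalModelTemplate.get_default_hparams())

# enable the learning rate finder so a tuner algorithm runs before training
trainer = Trainer(default_save_path='/tmp/lightning_demo', max_epochs=1, auto_lr_find=True)
trainer.fit(model)

# with this commit, the tuner no longer leaves a stale logger reference on the model
assert model.logger == trainer.logger

The two tests added below assert exactly this equality, once for the learning rate finder and once for the batch size finder.
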
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -401,6 +401,7 @@ def __init__(
 
         self.auto_lr_find = auto_lr_find
         self.auto_scale_batch_size = auto_scale_batch_size
+        self._is_data_prepared = False
         self.replace_sampler_ddp = replace_sampler_ddp
 
         self.truncated_bptt_steps = truncated_bptt_steps
@@ -823,17 +824,21 @@ def fit(
         # download the data and do whatever transforms we need
         # do before any spawn calls so that the model can assign properties
         # only on proc 0 because no spawn has happened yet
-        model.prepare_data()
+        if not self._is_data_prepared:
+            model.prepare_data()
+            self._is_data_prepared = True
 
         # Run auto batch size scaling
         if self.auto_scale_batch_size:
             if isinstance(self.auto_scale_batch_size, bool):
                 self.auto_scale_batch_size = 'power'
             self.scale_batch_size(model, mode=self.auto_scale_batch_size)
+            model.logger = self.logger  # reset logger binding
 
         # Run learning rate finder:
         if self.auto_lr_find:
             self._run_lr_finder_internally(model)
+            model.logger = self.logger  # reset logger binding
 
         # route to appropriate start method
         # when using multi-node or DDP within a node start each module in a separate process
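
The fit() hunk above makes two changes: prepare_data() is now guarded by the new _is_data_prepared flag, and model.logger is pointed back at the trainer's logger after each tuner algorithm. The guard matters because the batch size scaler and the learning rate finder re-enter fit() internally in this release, so an unguarded hook could run more than once per Trainer. The following self-contained sketch shows that guard pattern with stand-in classes; it is not the Lightning source.

class SketchModel:
    """Stand-in for a LightningModule; just enough to count prepare_data() calls."""

    def __init__(self):
        self.prepare_calls = 0

    def prepare_data(self):
        self.prepare_calls += 1  # imagine an expensive dataset download here


class SketchTrainer:
    """Stand-in for Trainer showing the one-shot guard added in this commit."""

    def __init__(self, auto_scale_batch_size=False):
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False  # mirrors the attribute added to __init__

    def scale_batch_size(self, model):
        self.fit(model, _from_tuner=True)  # the tuner calls back into fit()

    def fit(self, model, _from_tuner=False):
        if not self._is_data_prepared:  # guard keeps the hook to a single call
            model.prepare_data()
            self._is_data_prepared = True
        if self.auto_scale_batch_size and not _from_tuner:
            self.scale_batch_size(model)
        # ... training would continue here ...


model = SketchModel()
SketchTrainer(auto_scale_batch_size=True).fit(model)
assert model.prepare_calls == 1  # without the guard the hook would have fired twice
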
23 changes: 23 additions & 0 deletions tests/trainer/test_lr_finder.py
@@ -199,3 +199,26 @@ def test_suggestion_with_non_finite_values(tmpdir):
 
     assert before_lr == after_lr, \
         'Learning rate was altered because of non-finite loss values'
+
+
+def test_logger_reset_correctly(tmpdir):
+    """ Test that logger is updated correctly """
+    tutils.reset_seed()
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=10,
+        auto_lr_find=True
+    )
+    logger1 = trainer.logger
+    trainer.fit(model)
+    logger2 = trainer.logger
+    logger3 = model.logger
+
+    assert logger1 == logger2, \
+        'Learning rate finder altered the logger of trainer'
+    assert logger2 == logger3, \
+        'Learning rate finder altered the logger of model'
23 changes: 23 additions & 0 deletions tests/trainer/test_trainer_tricks.py
@@ -128,3 +128,26 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
 
     with pytest.raises(MisconfigurationException):
         trainer.fit(model, **fit_options)
+
+
+def test_logger_reset_correctly(tmpdir):
+    """ Test that logger is updated correctly """
+    tutils.reset_seed()
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        auto_scale_batch_size=True
+    )
+    logger1 = trainer.logger
+    trainer.fit(model)
+    logger2 = trainer.logger
+    logger3 = model.logger
+
+    assert logger1 == logger2, \
+        'Batch size finder altered the logger of trainer'
+    assert logger2 == logger3, \
+        'Batch size finder altered the logger of model'
