Fix logger bug and prepare data bug #1933

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))

- Fixed bug where the logger was not reset correctly for the model after running tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))

## [0.7.6] - 2020-05-16

### Added
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -400,6 +400,7 @@ def __init__(

        self.auto_lr_find = auto_lr_find
        self.auto_scale_batch_size = auto_scale_batch_size
+       self._is_data_prepared = False
        self.replace_sampler_ddp = replace_sampler_ddp

        self.truncated_bptt_steps = truncated_bptt_steps
@@ -822,17 +823,21 @@ def fit(
        # download the data and do whatever transforms we need
        # do before any spawn calls so that the model can assign properties
        # only on proc 0 because no spawn has happened yet
-       model.prepare_data()
+       if not self._is_data_prepared:
+           model.prepare_data()
+           self._is_data_prepared = True

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'
            self.scale_batch_size(model, mode=self.auto_scale_batch_size)
+           model.logger = self.logger  # reset logger binding

        # Run learning rate finder:
        if self.auto_lr_find:
            self._run_lr_finder_internally(model)
+           model.logger = self.logger  # reset logger binding

        # route to appropriate start method
        # when using multi-node or DDP within a node start each module in a separate process
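To make the intent of this hunk easier to follow, here is a minimal, self-contained sketch of the two fixes. ToyModel, ToyTrainer, and _run_lr_finder are hypothetical stand-ins, not the actual PyTorch Lightning classes: the guard keeps prepare_data() from running again when a tuner algorithm re-enters fit(), and the rebinding restores model.logger to the trainer's logger once the tuner has finished.

# Minimal sketch only -- simplified stand-ins, not the PyTorch Lightning source.

class ToyModel:
    def __init__(self):
        self.logger = None
        self.prepare_calls = 0

    def prepare_data(self):
        self.prepare_calls += 1


class ToyTrainer:
    def __init__(self, auto_lr_find=False):
        self.logger = object()          # stands in for the configured experiment logger
        self.auto_lr_find = auto_lr_find
        self._is_data_prepared = False  # the guard flag the PR adds in __init__

    def _run_lr_finder(self, model):
        # Tuner algorithms re-enter fit() internally and can leave the model
        # pointing at a temporary logger -- the behaviour this PR corrects.
        model.logger = object()
        self.fit(model)

    def fit(self, model):
        # Guard: prepare data only once, even if fit() is re-entered by a tuner.
        if not self._is_data_prepared:
            model.prepare_data()
            self._is_data_prepared = True

        if self.auto_lr_find:
            self.auto_lr_find = False   # toy simplification: run the finder only once
            self._run_lr_finder(model)
            model.logger = self.logger  # reset logger binding, as in the diff above


trainer = ToyTrainer(auto_lr_find=True)
model = ToyModel()
trainer.fit(model)
assert model.prepare_calls == 1         # prepare_data() ran exactly once
assert model.logger is trainer.logger   # logger binding restored after the tuner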
23 changes: 23 additions & 0 deletions tests/trainer/test_lr_finder.py
@@ -200,3 +200,26 @@ def test_suggestion_with_non_finite_values(tmpdir):

    assert before_lr == after_lr, \
        'Learning rate was altered because of non-finite loss values'


def test_logger_reset_correctly(tmpdir):
    """ Test that the trainer and model loggers are unchanged after running the learning rate finder """
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(hparams)

    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=10,
        auto_lr_find=True
    )
    logger1 = trainer.logger
    trainer.fit(model)
    logger2 = trainer.logger
    logger3 = model.logger

    assert logger1 == logger2, \
        'Learning rate finder altered the logger of trainer'
    assert logger2 == logger3, \
        'Learning rate finder altered the logger of model'
23 changes: 23 additions & 0 deletions tests/trainer/test_trainer_tricks.py
@@ -128,3 +128,26 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):

    with pytest.raises(MisconfigurationException):
        trainer.fit(model, **fit_options)


Review comment (Member) on the test added below: this shall go to tests/loggers

def test_logger_reset_correctly(tmpdir):
    """ Test that the trainer and model loggers are unchanged after batch size scaling """
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(hparams)

    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        auto_scale_batch_size=True
    )
    logger1 = trainer.logger
    trainer.fit(model)
    logger2 = trainer.logger
    logger3 = model.logger

    assert logger1 == logger2, \
        'Batch size finder altered the logger of trainer'
    assert logger2 == logger3, \
        'Batch size finder altered the logger of model'
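As a possible follow-up (not part of this PR), a combined check in the same style as the two tests above, assuming the same tutils/EvalModelTemplate imports used in those files; with both tuner flags enabled, fit() runs batch size scaling and then the learning rate finder, and the logger binding should survive both:

def test_logger_reset_correctly_with_both_tuners(tmpdir):
    """ Hypothetical combined check: both tuner algorithms together should
    leave the trainer and model loggers untouched """
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(hparams)

    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        auto_scale_batch_size=True,
        auto_lr_find=True
    )
    trainer.fit(model)

    assert trainer.logger == model.logger, \
        'Tuner algorithms altered the logger binding of the model'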