fix setup call while testing #2624

Merged (7 commits, Jul 25, 2020)
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed setup call while testing ([#2624](https://github.com/PyTorchLightning/pytorch-lightning/pull/2624))
+
 - Fixed Horovod backend to scale LR schedulers with the optimizer ([#2626](https://github.com/PyTorchLightning/pytorch-lightning/pull/2626))
 
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
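In effect, while `trainer.test()` is running, the trainer now skips both `self.setup('fit')` and `model.setup('fit')`, and the test paths call `model.setup('test')` unconditionally (see the `trainer.py` changes below). A minimal sketch of a module relying on the hook, assuming only the `setup(stage)` signature exercised in this PR; the class name and placeholder datasets are hypothetical:

import torch
from torch.utils.data import TensorDataset

from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    """Hypothetical module showing where per-stage data setup would go."""

    def setup(self, stage):
        # stage is 'fit' during trainer.fit() and 'test' during trainer.test()
        if stage == 'fit':
            self.train_dataset = TensorDataset(torch.randn(64, 10))  # placeholder data
        elif stage == 'test':
            self.test_dataset = TensorDataset(torch.randn(16, 10))  # placeholder data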
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -506,8 +506,8 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)
 
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # on world_size=0 let everyone know training is starting
16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -167,8 +167,8 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
 
     def single_gpu_train(self, model):
         # call setup
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         model.cuda(self.root_gpu)
@@ -189,8 +189,8 @@ def single_gpu_train(self, model):
 
     def tpu_train(self, tpu_core_idx, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # put model on tpu
@@ -229,8 +229,8 @@ def tpu_train(self, tpu_core_idx, model):
 
     def dp_train(self, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         model.cuda(self.root_gpu)
@@ -275,8 +275,8 @@ def dp_train(self, model):
 
     def horovod_train(self, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         if torch.cuda.is_available() and self.on_gpu:
10 changes: 4 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -1054,8 +1054,8 @@ def fit(
             raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
 
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # CHOOSE OPTIMIZER
@@ -1328,8 +1328,7 @@ def test(
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         model = self.get_model()
-        if self.is_function_implemented('setup', model):
-            model.setup('test')
+        model.setup('test')
 
         # if user requests the best checkpoint but we don't have it, error
         if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
@@ -1373,8 +1372,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
 
     def __test_given_model(self, model, test_dataloaders):
         # setup hook
-        if self.is_function_implemented('setup', model):
-            model.setup('test')
+        model.setup('test')
 
         # attach data
         if test_dataloaders is not None:
31 changes: 31 additions & 0 deletions tests/trainer/test_trainer.py
@@ -980,3 +980,34 @@ def test_trainer_pickle(tmpdir):
     )
     pickle.dumps(trainer)
     cloudpickle.dumps(trainer)
+
+
+def test_trainer_setup_call(tmpdir):
+    """Test setup call with fit and test call."""
+
+    class CurrentModel(EvalModelTemplate):
+
+        def setup(self, stage):
+            self.stage = stage
+
+    class TrainerSubclass(Trainer):
+
+        def setup(self, stage):
+            self.stage = stage
+
+    model = CurrentModel()
+
+    # fit model
+    trainer = TrainerSubclass(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        checkpoint_callback=False
+    )
+
+    trainer.fit(model)
+    assert trainer.stage == 'fit'
+    assert trainer.get_model().stage == 'fit'
+
+    trainer.test(ckpt_path=None)
+    assert trainer.stage == 'test'
+    assert trainer.get_model().stage == 'test'
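
As a follow-up usage sketch (not part of this PR), the same behaviour can be checked when a model is passed straight to `test()` without a prior `fit()`; this assumes hypothetical module-level copies of the `CurrentModel` and `TrainerSubclass` helpers, since the ones above are local to `test_trainer_setup_call`:

def test_trainer_setup_call_test_only(tmpdir):
    """Hypothetical companion check: .test(model) without a prior .fit()."""
    model = CurrentModel()  # assumed module-level copy of the helper above
    trainer = TrainerSubclass(
        default_root_dir=tmpdir,
        max_epochs=1,
        checkpoint_callback=False
    )

    # __test_given_model() now calls model.setup('test') unconditionally,
    # and the fit-stage setup is skipped because the trainer is in testing mode.
    trainer.test(model)
    assert trainer.get_model().stage == 'test'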