diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5ed0abbcfa801..022e3159e0cb7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed setup call while testing ([#2624](https://github.com/PyTorchLightning/pytorch-lightning/pull/2624))
+
 - Fixed Horovod backend to scale LR schedlers with the optimizer ([#2626](https://github.com/PyTorchLightning/pytorch-lightning/pull/2626))
 
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index d07f47df44673..38cc71a4f2195 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -509,8 +509,8 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)
 
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # on world_size=0 let everyone know training is starting
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index b8ff9a03e75a5..bf03514bc2c5a 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -167,8 +167,8 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
 
     def single_gpu_train(self, model):
         # call setup
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         model.cuda(self.root_gpu)
@@ -189,8 +189,8 @@ def single_gpu_train(self, model):
 
     def tpu_train(self, tpu_core_idx, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # put model on tpu
@@ -229,8 +229,8 @@ def tpu_train(self, tpu_core_idx, model):
 
     def dp_train(self, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         model.cuda(self.root_gpu)
@@ -275,8 +275,8 @@ def dp_train(self, model):
 
     def horovod_train(self, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         if torch.cuda.is_available() and self.on_gpu:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 71a32397b491b..671958d729802 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1087,8 +1087,8 @@ def fit(
                 raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
 
             # call setup after the ddp process has connected
-            self.setup('fit')
-            if self.is_function_implemented('setup', model):
+            if not self.testing:
+                self.setup('fit')
                 model.setup('fit')
 
             # CHOOSE OPTIMIZER
@@ -1381,8 +1381,7 @@ def test(
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         model = self.get_model()
-        if self.is_function_implemented('setup', model):
-            model.setup('test')
+        model.setup('test')
 
         # if user requests the best checkpoint but we don't have it, error
         if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
@@ -1429,8 +1428,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
 
     def __test_given_model(self, model, test_dataloaders):
         # setup hook
-        if self.is_function_implemented('setup', model):
-            model.setup('test')
+        model.setup('test')
 
         # attach data
         if test_dataloaders is not None:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d21767ab29409..06039e76b8277 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -980,3 +980,34 @@ def test_trainer_pickle(tmpdir):
     )
     pickle.dumps(trainer)
     cloudpickle.dumps(trainer)
+
+
+def test_trainer_setup_call(tmpdir):
+    """Test setup call with fit and test call."""
+
+    class CurrentModel(EvalModelTemplate):
+
+        def setup(self, stage):
+            self.stage = stage
+
+    class TrainerSubclass(Trainer):
+
+        def setup(self, stage):
+            self.stage = stage
+
+    model = CurrentModel()
+
+    # fit model
+    trainer = TrainerSubclass(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        checkpoint_callback=False
+    )
+
+    trainer.fit(model)
+    assert trainer.stage == 'fit'
+    assert trainer.get_model().stage == 'fit'
+
+    trainer.test(ckpt_path=None)
+    assert trainer.stage == 'test'
+    assert trainer.get_model().stage == 'test'
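Usage note (an illustrative sketch, not part of the diff): after this change the trainer calls `setup('fit')` only when it is not testing, and `model.setup('test')` runs unconditionally before a test rather than being gated on `is_function_implemented`. A minimal sketch of how a user-defined `setup` hook observes the stage, mirroring the new test above; `MyModel` and the `prepare_*` helpers are hypothetical names:

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        # other required hooks (forward, training_step, configure_optimizers, ...) omitted

        def setup(self, stage):
            # stage is 'fit' during trainer.fit(...) and 'test' during trainer.test(...)
            if stage == 'fit':
                self.prepare_training_data()  # hypothetical helper
            elif stage == 'test':
                self.prepare_test_data()      # hypothetical helper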