fix setup call while testing (#2624)
* fix setup call while testing
* changelog
* drop if condition
* add test to check setup call
* flake8
* update test to check model stage

Co-authored-by: William Falcon <waf2107@columbia.edu>
rohitgr7 and williamFalcon committed Jul 25, 2020
1 parent 8599b67 commit cb0c6ad
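
In short: `trainer.test()` funnels through the same code paths as `fit()`, so before this change a test-only run still fired `self.setup('fit')`, and `model.setup` was only invoked when `is_function_implemented` detected an override. The fix guards the `'fit'` stage with `not self.testing` and calls `model.setup` unconditionally, which is safe because the base hook is a no-op. A minimal sketch of the resulting call-order contract (simplified stand-ins, not the actual library internals):

```python
# Simplified stand-ins for illustration; not the real pytorch-lightning
# classes, just the call-order contract this commit enforces.
class LightningModule:
    def setup(self, stage):
        """No-op default; users override to build state per stage."""

class Trainer:
    def __init__(self):
        self.testing = False

    def setup(self, stage):
        """Trainer-level hook, overridable in subclasses."""

    def fit(self, model):
        # test() funnels through fit() internally, hence the guard:
        # a test-only run must never enter the 'fit' stage.
        if not self.testing:
            self.setup('fit')
            model.setup('fit')

    def test(self, model):
        self.testing = True
        self.setup('test')
        model.setup('test')  # unconditional; safe thanks to the no-op default
        self.fit(model)      # reuses the fit path, but skips 'fit' setup
```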
Showing 5 changed files with 47 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed setup call while testing ([#2624](https://github.com/PyTorchLightning/pytorch-lightning/pull/2624))
+
 - Fixed Horovod backend to scale LR schedulers with the optimizer ([#2626](https://github.com/PyTorchLightning/pytorch-lightning/pull/2626))
 
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -509,8 +509,8 @@ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)
 
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # on world_size=0 let everyone know training is starting
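
For context, this hook runs after the distributed workers connect, so a user-side override like the following (a generic illustration, not part of this diff) can safely build stage-specific state:

```python
import torch
from torch.utils.data import TensorDataset
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    """Illustrative override only; not taken from this commit."""

    def setup(self, stage):
        # Runs once per stage after DDP processes have connected; with this
        # fix, stage is 'fit' for trainer.fit() and 'test' for trainer.test().
        if stage == 'fit':
            self.train_data = TensorDataset(torch.randn(64, 32))
        else:  # 'test'
            self.test_data = TensorDataset(torch.randn(16, 32))
```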
16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -167,8 +167,8 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
 
     def single_gpu_train(self, model):
         # call setup
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         model.cuda(self.root_gpu)
@@ -189,8 +189,8 @@ def single_gpu_train(self, model):
 
     def tpu_train(self, tpu_core_idx, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # put model on tpu
@@ -229,8 +229,8 @@ def tpu_train(self, tpu_core_idx, model):
 
     def dp_train(self, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         model.cuda(self.root_gpu)
@@ -275,8 +275,8 @@ def dp_train(self, model):
 
     def horovod_train(self, model):
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         if torch.cuda.is_available() and self.on_gpu:
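
The same three-line guard now appears in `single_gpu_train`, `tpu_train`, `dp_train`, and `horovod_train`. One way to avoid the repetition would be a small helper on the mixin in this file, sketched below (hypothetical; not introduced by this commit):

```python
class TrainerDPMixin:
    def _call_fit_setup_hook(self, model):
        # Hypothetical deduplication of the block repeated across the
        # four backends above; not part of this commit.
        if not self.testing:
            self.setup('fit')
            model.setup('fit')
```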
10 changes: 4 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -1087,8 +1087,8 @@ def fit(
             raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
 
         # call setup after the ddp process has connected
-        self.setup('fit')
-        if self.is_function_implemented('setup', model):
+        if not self.testing:
+            self.setup('fit')
             model.setup('fit')
 
         # CHOOSE OPTIMIZER
@@ -1381,8 +1381,7 @@ def test(
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         model = self.get_model()
-        if self.is_function_implemented('setup', model):
-            model.setup('test')
+        model.setup('test')
 
         # if user requests the best checkpoint but we don't have it, error
         if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:

def __test_given_model(self, model, test_dataloaders):
# setup hook
if self.is_function_implemented('setup', model):
model.setup('test')
model.setup('test')

# attach data
if test_dataloaders is not None:
Expand Down
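
Both private test helpers previously wrapped the hook in `is_function_implemented`, a reflection check for whether the user overrode `setup`. With a no-op default on the base class, the reflection adds nothing. A sketch of the two styles (hypothetical stand-ins, not the library source):

```python
# Hypothetical illustration of the dropped reflection guard.
class Base:
    def setup(self, stage):
        """No-op default."""

class WithOverride(Base):
    def setup(self, stage):
        self.stage = stage

def is_function_implemented(obj, name):
    # Rough stand-in for the trainer utility: was the method overridden?
    return getattr(type(obj), name) is not getattr(Base, name)

m = WithOverride()
# old style: check, then call
if is_function_implemented(m, 'setup'):
    m.setup('test')
# new style: just call; the no-op default makes the guard redundant
m.setup('test')
```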
31 changes: 31 additions & 0 deletions tests/trainer/test_trainer.py
@@ -980,3 +980,34 @@ def test_trainer_pickle(tmpdir):
     )
     pickle.dumps(trainer)
     cloudpickle.dumps(trainer)
+
+
+def test_trainer_setup_call(tmpdir):
+    """Test setup call with fit and test call."""
+
+    class CurrentModel(EvalModelTemplate):
+
+        def setup(self, stage):
+            self.stage = stage
+
+    class TrainerSubclass(Trainer):
+
+        def setup(self, stage):
+            self.stage = stage
+
+    model = CurrentModel()
+
+    # fit model
+    trainer = TrainerSubclass(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        checkpoint_callback=False
+    )
+
+    trainer.fit(model)
+    assert trainer.stage == 'fit'
+    assert trainer.get_model().stage == 'fit'
+
+    trainer.test(ckpt_path=None)
+    assert trainer.stage == 'test'
+    assert trainer.get_model().stage == 'test'
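
To run just the new test locally, a standard pytest selection works; shown here via `pytest.main`, equivalent to the CLI `pytest tests/trainer/test_trainer.py -k test_trainer_setup_call`:

```python
import pytest

# Select only the new test from the repository root.
pytest.main(['tests/trainer/test_trainer.py', '-k', 'test_trainer_setup_call'])
```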
