diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py
index 19a85b87949df..20fb1cae24732 100644
--- a/pl_examples/domain_templates/imagenet.py
+++ b/pl_examples/domain_templates/imagenet.py
@@ -245,7 +245,7 @@ def main(args: Namespace) -> None:
     )

     if args.evaluate:
-        trainer.run_evaluation()
+        trainer.test()
     else:
         trainer.fit(model)

diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 16f68f1e13502..440a4ea4e6ac3 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -176,6 +176,7 @@ class TrainerEvaluationLoopMixin(ABC):
     use_tpu: bool
     reload_dataloaders_every_epoch: ...
     tpu_id: int
+    verbose_test: bool

     # Callback system
     on_validation_batch_start: Callable
@@ -307,15 +308,16 @@ def _evaluate(
                     self.on_validation_batch_end()

                 # track outputs for collation
-                dl_outputs.append(output)
+                if output is not None:
+                    dl_outputs.append(output)

             outputs.append(dl_outputs)

-        eval_results = {}
+        eval_results = outputs

         # with a single dataloader don't pass an array
         if len(dataloaders) == 1:
-            outputs = outputs[0]
+            eval_results = outputs[0]

         # give model a chance to do something with the outputs (and method defined)
         if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
@@ -324,22 +326,22 @@
         if test_mode:
             if self.is_overridden('test_end', model=model):
                 # TODO: remove in v1.0.0
-                eval_results = model.test_end(outputs)
+                eval_results = model.test_end(eval_results)
                 rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
                                ' Use `test_epoch_end` instead.', DeprecationWarning)

             elif self.is_overridden('test_epoch_end', model=model):
-                eval_results = model.test_epoch_end(outputs)
+                eval_results = model.test_epoch_end(eval_results)

         else:
             if self.is_overridden('validation_end', model=model):
                 # TODO: remove in v1.0.0
-                eval_results = model.validation_end(outputs)
+                eval_results = model.validation_end(eval_results)
                 rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
                                ' Use `validation_epoch_end` instead.', DeprecationWarning)

             elif self.is_overridden('validation_epoch_end', model=model):
-                eval_results = model.validation_epoch_end(outputs)
+                eval_results = model.validation_epoch_end(eval_results)

         # enable train mode again
         model.train()
@@ -385,31 +387,40 @@ def run_evaluation(self, test_mode: bool = False):
         # enable disabling validation step with limit_val_batches = 0
         should_skip = sum(max_batches) == 0
         if should_skip:
-            return
+            return [], []

         # run evaluation
         eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)

         # enable no returns
-        callback_metrics = {}
+        eval_loop_results = []
         if eval_results is not None and len(eval_results) > 0:
-            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)
-
-            # add metrics to prog bar
-            self.add_progress_bar_metrics(prog_bar_metrics)
+
+            # in eval, the user may return something at every validation step without final reduction
+            if not isinstance(eval_results, list):
+                eval_results = [eval_results]
+
+            for result in eval_results:
+                _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result)
+
+                # add metrics to prog bar
+                self.add_progress_bar_metrics(prog_bar_metrics)
+
+                # log results of test
+                if test_mode and self.is_global_zero and self.verbose_test:
+                    print('-' * 80)
+                    print('TEST RESULTS')
+                    pprint(callback_metrics)
+                    print('-' * 80)

-            # log results of test
-            if test_mode and self.is_global_zero:
-                print('-' * 80)
-                print('TEST RESULTS')
-                pprint(callback_metrics)
-                print('-' * 80)
+                # log metrics
+                self.log_metrics(log_metrics, {})

-            # log metrics
-            self.log_metrics(log_metrics, {})
+                # track metrics for callbacks
+                self.callback_metrics.update(callback_metrics)

-            # track metrics for callbacks
-            self.callback_metrics.update(callback_metrics)
+                if len(callback_metrics) > 0:
+                    eval_loop_results.append(callback_metrics)

         # hook
         model.on_post_performance_check()
@@ -429,7 +440,7 @@ def run_evaluation(self, test_mode: bool = False):
         else:
             self.on_validation_end()

-        return callback_metrics
+        return eval_loop_results, eval_results

     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
         # make dataloader_idx arg in validation_step optional
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 770dc4b314688..1f611ab7ac57c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -128,8 +128,9 @@ class Trainer(
         >>> trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
         >>> trainer.fit(model, train_loader)
         1
-        >>> trainer.test(model, train_loader)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-        1
+        >>> test_outputs = trainer.test(model, train_loader, verbose=False)
+        >>> len(test_outputs)
+        25
     """

     DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores')
@@ -396,6 +397,9 @@ def __init__(
         self.test_dataloaders = None
         self.val_dataloaders = None

+        # when true, prints test results
+        self.verbose_test = True
+
         # when .test() is called, it sets this
         self.tested_ckpt_path = None

@@ -1125,7 +1129,6 @@ def run_pretrain_routine(self, model: LightningModule):

         if self.logger is not None:
             # save exp to get started
             self.logger.log_hyperparams(ref_model.hparams)
-            self.logger.save()

         if self.use_ddp or self.use_ddp2:
@@ -1163,22 +1166,38 @@
         if self.testing:
             # only load test dataloader for testing
             # self.reset_test_dataloader(ref_model)
-            results = self.run_evaluation(test_mode=True)
-
-            # remove all cuda tensors
-            if results is not None and isinstance(results, dict) and len(results) > 0:
-                for k, v in results.items():
-                    if isinstance(v, torch.Tensor):
-                        results[k] = v.cpu().item()
+            eval_loop_results, _ = self.run_evaluation(test_mode=True)

-                return results
-            else:
+            if len(eval_loop_results) == 0:
                 return 1

+            # remove the tensors from the eval results
+            for i, result in enumerate(eval_loop_results):
+                if isinstance(result, dict):
+                    for k, v in result.items():
+                        if isinstance(v, torch.Tensor):
+                            result[k] = v.cpu().item()
+
+            return eval_loop_results
+
         # check if we should run validation during training
         self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
             and not self.fast_dev_run

+        # run a few val batches before training starts
+        self._run_sanity_check(ref_model, model)
+
+        # clear cache before training
+        if self.on_gpu and self.root_gpu is not None:
+            # use context because of:
+            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
+            with torch.cuda.device(f'cuda:{self.root_gpu}'):
+                torch.cuda.empty_cache()
+
+        # CORE TRAINING LOOP
+        self.train()
+
+    def _run_sanity_check(self, ref_model, model):
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         if not self.disable_validation and self.num_sanity_val_steps > 0:
@@ -1197,26 +1216,20 @@ def run_pretrain_routine(self, model: LightningModule):

             # allow no returns from eval
             if eval_results is not None and len(eval_results) > 0:
+                # when we get a list back, used only the last item
+                if isinstance(eval_results, list):
+                    eval_results = eval_results[-1]
                 _, _, _, callback_metrics, _ = self.process_output(eval_results)
                 self.callback_metrics = callback_metrics

             self.on_sanity_check_end()

-        # clear cache before training
-        if self.on_gpu and self.root_gpu is not None:
-            # use context because of:
-            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
-            with torch.cuda.device(f'cuda:{self.root_gpu}'):
-                torch.cuda.empty_cache()
-
-        # CORE TRAINING LOOP
-        self.train()
-
     def test(
             self,
             model: Optional[LightningModule] = None,
             test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
-            ckpt_path: Optional[str] = 'best'
+            ckpt_path: Optional[str] = 'best',
+            verbose: bool = True
     ):
         r"""
@@ -1231,6 +1244,11 @@ def test(
             ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
                 If ``None``, use the weights from the last epoch to test. Default to ``best``.

+            verbose: If True, prints the test results
+
+        Returns:
+            The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries
+
         Example::

             # Option 1
@@ -1270,6 +1288,8 @@ def test(
         # --------------------
         # SETUP HOOK
         # --------------------
+        self.verbose_test = verbose
+
         if self.global_rank != 0:
             return

diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py
index c387997da57d7..a4988673c60a4 100644
--- a/tests/base/deterministic_model.py
+++ b/tests/base/deterministic_model.py
@@ -15,6 +15,10 @@ def __init__(self, weights=None):
         self.training_step_end_called = False
         self.training_epoch_end_called = False

+        self.validation_step_called = False
+        self.validation_step_end_called = False
+        self.validation_epoch_end_called = False
+
         self.l1 = nn.Linear(2, 3, bias=False)
         if weights is None:
             weights = torch.tensor([
@@ -162,13 +166,61 @@ def training_epoch_end_dict(self, outputs):

         return {'log': logs, 'progress_bar': pbar}

+    def validation_step_no_return(self, batch, batch_idx):
+        self.validation_step_called = True
+        acc = self.step(batch, batch_idx)
+
+    def validation_step_scalar_return(self, batch, batch_idx):
+        self.validation_step_called = True
+        acc = self.step(batch, batch_idx)
+        return acc
+
+    def validation_step_arbitary_dict_return(self, batch, batch_idx):
+        self.validation_step_called = True
+        acc = self.step(batch, batch_idx)
+        return {'some': acc, 'value': 'a'}
+
     def validation_step_dict_return(self, batch, batch_idx):
+        self.validation_step_called = True
         acc = self.step(batch, batch_idx)

-        logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)}
+        logs = {'log_acc1': torch.tensor(12 + batch_idx).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)}
         pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)}
         return {'val_loss': acc, 'log': logs, 'progress_bar': pbar}

+    def validation_step_end_no_return(self, val_step_output):
+        assert len(val_step_output) == 3
+        assert val_step_output['val_loss'] == 171
+        assert val_step_output['log']['log_acc1'] >= 12
+        assert val_step_output['progress_bar']['pbar_acc1'] == 17
+        self.validation_step_end_called = True
+
+    def validation_step_end(self, val_step_output):
+        assert len(val_step_output) == 3
+        assert val_step_output['val_loss'] == 171
+        assert val_step_output['log']['log_acc1'] >= 12
+        assert val_step_output['progress_bar']['pbar_acc1'] == 17
+        self.validation_step_end_called = True
+
+        val_step_output['val_step_end'] = torch.tensor(1802)
+
+        return val_step_output
+
+    def validation_epoch_end(self, outputs):
+        assert len(outputs) == self.trainer.num_val_batches[0]
+
+        for i, out in enumerate(outputs):
+            assert out['log']['log_acc1'] >= 12 + i
+
+        self.validation_epoch_end_called = True
+
+        result = outputs[-1]
+        result['val_epoch_end'] = torch.tensor(1233)
+        return result
+
+    # -----------------------------
+    # DATA
+    # -----------------------------
     def train_dataloader(self):
         return DataLoader(DummyDataset(), batch_size=3, shuffle=False)

diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 9331d6c7a540f..244439f7634d7 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -52,7 +52,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
     pretrained_model.cpu()

     # test we have good test accuracy
-    acc = results['test_acc']
+    acc = results[0]['test_acc']
     assert acc > 0.5, f"Model failed to get expected {0.5} accuracy. test_acc = {acc}"

     dataloaders = model.test_dataloader()
@@ -102,7 +102,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
     results = new_trainer.test(pretrained_model)
     pretrained_model.cpu()

-    acc = results['test_acc']
+    acc = results[0]['test_acc']
     assert acc > 0.5, f"Model failed to get expected {0.5} accuracy. test_acc = {acc}"

     dataloaders = model.test_dataloader()
diff --git a/tests/models/test_test_loop.py b/tests/models/test_test_loop.py
index 89103116bd8f3..c65809ad25221 100644
--- a/tests/models/test_test_loop.py
+++ b/tests/models/test_test_loop.py
@@ -21,12 +21,12 @@ def test_single_gpu_test(tmpdir):
     trainer.fit(model)
     assert 'ckpt' in trainer.checkpoint_callback.best_model_path
     results = trainer.test()
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     old_weights = model.c_d1.weight.clone().detach().cpu()

     results = trainer.test(model)
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     # make sure weights didn't change
     new_weights = model.c_d1.weight.clone().detach().cpu()
@@ -50,12 +50,12 @@ def test_dp_test(tmpdir):
     trainer.fit(model)
     assert 'ckpt' in trainer.checkpoint_callback.best_model_path
     results = trainer.test()
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     old_weights = model.c_d1.weight.clone().detach().cpu()

     results = trainer.test(model)
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     # make sure weights didn't change
     new_weights = model.c_d1.weight.clone().detach().cpu()
@@ -79,12 +79,12 @@ def test_ddp_spawn_test(tmpdir):
     trainer.fit(model)
     assert 'ckpt' in trainer.checkpoint_callback.best_model_path
     results = trainer.test()
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     old_weights = model.c_d1.weight.clone().detach().cpu()

     results = trainer.test(model)
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     # make sure weights didn't change
     new_weights = model.c_d1.weight.clone().detach().cpu()
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index e76ef0e556352..85b706e1dc9a4 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -295,7 +295,6 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
     ]
     assert trainer.num_test_batches == expected_test_batches

-
 @pytest.mark.parametrize(
     ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
     [
diff --git a/tests/trainer/test_eval_loop_dict_return.py b/tests/trainer/test_eval_loop_dict_return.py
new file mode 100644
index 0000000000000..d4e845badeb9b
--- /dev/null
+++ b/tests/trainer/test_eval_loop_dict_return.py
@@ -0,0 +1,305 @@
+"""
+Tests to ensure that the training loop works with a dict
+"""
+from pytorch_lightning import Trainer
+from tests.base.deterministic_model import DeterministicModel
+
+
+def test_validation_step_no_return(tmpdir):
+    """
+    Test that val step can return nothing
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.validation_step = model.validation_step_no_return
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # out are the results of the full loop
+    # eval_results are output of _evaluate
+    out, eval_results = trainer.run_evaluation(test_mode=False)
+    assert len(out) == 0
+    assert len(eval_results) == 0
+
+    # make sure correct steps were called
+    assert model.validation_step_called
+    assert not model.validation_step_end_called
+    assert not model.validation_epoch_end_called
+
+
+def test_validation_step_scalar_return(tmpdir):
+    """
+    Test that val step can return a scalar
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.validation_step = model.validation_step_scalar_return
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        weights_summary=None,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2
+    )
+    trainer.fit(model)
+
+    # out are the results of the full loop
+    # eval_results are output of _evaluate
+    out, eval_results = trainer.run_evaluation(test_mode=False)
+    assert len(out) == 0
+    assert len(eval_results) == 2
+    assert eval_results[0] == 171 and eval_results[1] == 171
+
+    # make sure correct steps were called
+    assert model.validation_step_called
+    assert not model.validation_step_end_called
+    assert not model.validation_epoch_end_called
+
+
+def test_validation_step_arbitrary_dict_return(tmpdir):
+    """
+    Test that val step can return an arbitrary dict
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.validation_step = model.validation_step_arbitary_dict_return
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        weights_summary=None,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2
+    )
+    trainer.fit(model)
+
+    # out are the results of the full loop
+    # eval_results are output of _evaluate
+    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    assert len(callback_metrics) == 2
+    assert len(eval_results) == 2
+    assert eval_results[0]['some'] == 171
+    assert eval_results[1]['some'] == 171
+
+    assert eval_results[0]['value'] == 'a'
+    assert eval_results[1]['value'] == 'a'
+
+    # make sure correct steps were called
+    assert model.validation_step_called
+    assert not model.validation_step_end_called
+    assert not model.validation_epoch_end_called
+
+
+def test_validation_step_dict_return(tmpdir):
+    """
+    Test that val step can return a dict with all the expected keys and they end up
+    in the correct place
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.validation_step = model.validation_step_dict_return
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        weights_summary=None,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2
+    )
+    trainer.fit(model)
+
+    # out are the results of the full loop
+    # eval_results are output of _evaluate
+    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    assert len(callback_metrics) == 2
+    assert len(callback_metrics[0]) == 5
+    assert len(eval_results) == 2
+    assert eval_results[0]['log']['log_acc1'] == 12
+    assert eval_results[1]['log']['log_acc1'] == 13
+
+    for k in ['val_loss', 'log', 'progress_bar']:
+        assert k in eval_results[0]
+        assert k in eval_results[1]
+
+    # ensure all the keys ended up as candidates for callbacks
+    assert len(trainer.callback_metrics) == 8
+
+    # make sure correct steps were called
+    assert model.validation_step_called
+    assert not model.validation_step_end_called
+    assert not model.validation_epoch_end_called
+
+
+def test_val_step_step_end_no_return(tmpdir):
+    """
+    Test that val step + val step end work (with no return in val step end)
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.validation_step = model.validation_step_dict_return
+    model.validation_step_end = model.validation_step_end_no_return
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        weights_summary=None,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2
+    )
+    trainer.fit(model)
+
+    # out are the results of the full loop
+    # eval_results are output of _evaluate
+    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    assert len(callback_metrics) == 0
+    assert len(eval_results) == 0
+
+    # make sure correct steps were called
+    assert model.validation_step_called
+    assert model.validation_step_end_called
+    assert not model.validation_epoch_end_called
+
+
+def test_val_step_step_end(tmpdir):
+    """
+    Test that val step + val step end work
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.validation_step = model.validation_step_dict_return
+    model.validation_step_end = model.validation_step_end
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        weights_summary=None,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2
+    )
+    trainer.fit(model)
+
+    # out are the results of the full loop
+    # eval_results are output of _evaluate
+    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    assert len(callback_metrics) == 2
+    assert len(callback_metrics[0]) == 6
+
+    callback_metrics = callback_metrics[0]
+    assert callback_metrics['val_step_end'] == 1802
+    assert len(eval_results) == 2
+    assert eval_results[0]['log']['log_acc1'] == 12
+    assert eval_results[1]['log']['log_acc1'] == 13
+
+    for k in ['val_loss', 'log', 'progress_bar']:
+        assert k in eval_results[0]
+        assert k in eval_results[1]
+
+    # ensure all the keys ended up as candidates for callbacks
+    assert len(trainer.callback_metrics) == 9
+
+    # make sure correct steps were called
+    assert model.validation_step_called
+    assert model.validation_step_end_called
+    assert not model.validation_epoch_end_called
+
+
+def test_no_val_step_end(tmpdir):
+    """
+    Test that val step + val epoch end
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.validation_step = model.validation_step_dict_return
+    model.validation_step_end = None
+    model.validation_epoch_end = model.validation_epoch_end
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        weights_summary=None,
+        limit_train_batches=2,
+        limit_val_batches=3,
+        num_sanity_val_steps=0,
+        max_epochs=2
+    )
+    trainer.fit(model)
+
+    # out are the results of the full loop
+    # eval_results are output of _evaluate
+    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    assert len(callback_metrics) == 1
+    assert len(callback_metrics[0]) == 6
+    assert len(eval_results) == 1
+
+    eval_results = eval_results[0]
+    assert 'val_step_end' not in eval_results
+    assert eval_results['val_epoch_end'] == 1233
+
+    for k in ['val_loss', 'log', 'progress_bar']:
+        assert k in eval_results
+
+    # ensure all the keys ended up as candidates for callbacks
+    assert len(trainer.callback_metrics) == 9
+
+    # make sure correct steps were called
+    assert model.validation_step_called
+    assert not model.validation_step_end_called
+    assert model.validation_epoch_end_called
+
+
+def test_full_val_loop(tmpdir):
+    """
+    Test that val step + val step end + val epoch end
+    """
+    model = DeterministicModel()
+    model.training_step = model.training_step_dict_return
+    model.validation_step = model.validation_step_dict_return
+    model.validation_step_end = model.validation_step_end
+    model.validation_epoch_end = model.validation_epoch_end
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        weights_summary=None,
+        limit_train_batches=2,
+        limit_val_batches=3,
+        num_sanity_val_steps=0,
+        max_epochs=2
+    )
+    trainer.fit(model)
+
+    # out are the results of the full loop
+    # eval_results are output of _evaluate
+    callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
+    assert len(callback_metrics) == 1
+    assert len(callback_metrics[0]) == 7
+    assert len(eval_results) == 1
+
+    eval_results = eval_results[0]
+    assert eval_results['val_step_end'] == 1802
+    assert eval_results['val_epoch_end'] == 1233
+
+    for k in ['val_loss', 'log', 'progress_bar']:
+        assert k in eval_results
+
+    # ensure all the keys ended up as candidates for callbacks
+    assert len(trainer.callback_metrics) == 10
+
+    # make sure correct steps were called
+    assert model.validation_step_called
+    assert model.validation_step_end_called
+    assert model.validation_epoch_end_called