Result-object step variants for the test templates, a full-loop test that uses
them, and an evaluation-loop change that empties callback metrics in test mode.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Leading
indentation and blank lines were reconstructed from Python syntax and the hunk
line counts -- confirm against the target tree before applying.
NOTE(review): the new-side blob ids on the `index` lines predate the edits made
to the added functions here and are informational only.
NOTE(review): model_test_steps.py and model_train_steps.py use EvalResult and
TrainResult, but no import for them is added in this patch -- confirm both
names are already imported in those files.

diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 3b07b81dae2c18..d7f5f3eb0fd693 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -608,6 +608,10 @@ def __log_evaluation_epoch_metrics(self, eval_results, test_mode):
                     prog_bar_metrics = result.epoch_pbar_metrics
                     log_metrics = result.epoch_log_metrics
                     callback_metrics = result.callback_metrics
+
+                    # in testing we don't need the callback metrics
+                    if test_mode:
+                        callback_metrics = {}
                 else:
                     _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result)
 
diff --git a/tests/base/model_test_steps.py b/tests/base/model_test_steps.py
index 496ede49f0150a..db5ad1ed33ef4d 100644
--- a/tests/base/model_test_steps.py
+++ b/tests/base/model_test_steps.py
@@ -49,6 +49,31 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
         })
         return output
 
+    def test_step_result_obj(self, batch, batch_idx, *args, **kwargs):
+        """Test step that reports its metrics through an ``EvalResult``.
+
+        Logs ``test_loss`` and ``test_acc`` for the batch and returns the result object.
+        """
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x)
+
+        loss_test = self.loss(y, y_hat)
+
+        # acc: fraction of samples whose arg-max prediction equals the label
+        labels_hat = torch.argmax(y_hat, dim=1)
+        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        test_acc = torch.tensor(test_acc).type_as(x)
+
+        # every batch returns the same EvalResult: an always-true ``batch_idx % 1 == 0``
+        # gate and the unreachable alternative returns behind it were removed
+        result = EvalResult()
+        result.log_dict({
+            'test_loss': loss_test,
+            'test_acc': test_acc,
+        })
+        return result
+
     def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
         """
         Default, baseline test_step
diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py
index 16d05680c94dc8..b6fc748037af3f 100644
--- a/tests/base/model_train_steps.py
+++ b/tests/base/model_train_steps.py
@@ -38,6 +38,26 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
         )
         return output
 
+    def training_step_result_obj(self, batch, batch_idx, optimizer_idx=None):
+        """Training step that wraps the batch loss in a ``TrainResult`` and logs two derived values."""
+        # forward pass
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x)
+
+        # calculate loss
+        loss_val = self.loss(y, y_hat)
+        log_val = loss_val
+
+        # alternate between tensors and scalars for "log" and "progress_bar"
+        if batch_idx % 2 == 0:
+            log_val = log_val.item()
+
+        result = TrainResult(loss_val)
+        result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
+        result.log('train_some_val', log_val * log_val)
+        return result
+
     def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None):
         output = self.training_step(batch, batch_idx, optimizer_idx)
         if batch_idx == self.test_step_inf_loss:
diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py
index b93ef43e35f1d2..74694d365545c9 100644
--- a/tests/base/model_valid_steps.py
+++ b/tests/base/model_valid_steps.py
@@ -1,5 +1,6 @@
 from abc import ABC
 from collections import OrderedDict
+from pytorch_lightning.core.step_result import EvalResult
 
 import torch
 
@@ -32,6 +33,26 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
         })
         return output
 
+    def validation_step_result_obj(self, batch, batch_idx, *args, **kwargs):
+        """Validation step whose ``EvalResult`` gets the batch loss as ``checkpoint_on`` and ``early_stop_on``."""
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x)
+
+        loss_val = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        val_acc = torch.tensor(val_acc).type_as(x)
+
+        result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
+        result.log_dict({
+            'val_loss': loss_val,
+            'val_acc': val_acc,
+        })
+        return result
+
     def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
         """
         Lightning calls this inside the validation loop
diff --git a/tests/trainer/test_validation_steps_result_return.py b/tests/trainer/test_validation_steps_result_return.py
index 28f012535d1ab7..d6295840c13494 100644
--- a/tests/trainer/test_validation_steps_result_return.py
+++ b/tests/trainer/test_validation_steps_result_return.py
@@ -9,6 +9,7 @@
 from pytorch_lightning import Trainer
 from tests.base import EvalModelTemplate
 from tests.base.deterministic_model import DeterministicModel
+from pytorch_lightning import seed_everything
 
 
 # test with train_step_end
@@ -435,3 +436,45 @@ def test_val_step_full_loop_result_dp(tmpdir):
     assert 'epoch_test_step_metric' in seen_keys
     assert 'test_step_end_metric' in seen_keys
     assert 'test_epoch_end_metric' in seen_keys
+
+
+def test_full_loop_result_cpu(tmpdir):
+    """Fit then test a model whose steps return Result objects; only the logged test metrics may come back."""
+    seed_everything(1234)
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    batches = 5
+    epochs = 2
+
+    model = EvalModelTemplate()
+    model.training_step = model.training_step_result_obj
+    model.training_step_end = None
+    model.training_epoch_end = None
+    model.validation_step = model.validation_step_result_obj
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    model.test_step = model.test_step_result_obj
+    model.test_step_end = None
+    model.test_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=epochs,
+        early_stop_callback=True,
+        row_log_interval=2,
+        limit_train_batches=batches,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+
+    results = trainer.test()
+
+    # assert we returned all metrics requested
+    assert len(results) == 1
+    results = results[0]
+    assert results['test_loss'] < 0.3
+    assert results['test_acc'] > 0.9
+    assert len(results) == 2
+    assert 'val_early_stop_on' not in results
+    assert 'val_checkpoint_on' not in results