removed callback metrics from test results obj (#2994)
williamFalcon committed Aug 16, 2020
1 parent 766d0f3 commit d702d4d
Showing 5 changed files with 123 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -608,6 +608,10 @@ def __log_evaluation_epoch_metrics(self, eval_results, test_mode):
                prog_bar_metrics = result.epoch_pbar_metrics
                log_metrics = result.epoch_log_metrics
                callback_metrics = result.callback_metrics

                # in testing we don't need the callback metrics
                if test_mode:
                    callback_metrics = {}
            else:
                _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result)

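Note on the change above: when `test_mode` is set, the callback metrics carried by an `EvalResult` (for example `checkpoint_on` and `early_stop_on`) are emptied before the evaluation-epoch metrics are handed on, so `trainer.test()` surfaces only the explicitly logged keys. The sketch below is not part of the commit; it assumes the 0.9-era `EvalResult`/`TrainResult` API used by the test templates further down, and the names `TinyModel` and `_loader` are made up for illustration.

# Minimal sketch (not from this commit) of the behaviour the change enables.
import torch
import pytorch_lightning as pl
from pytorch_lightning.core.step_result import EvalResult, TrainResult
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return TrainResult(loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        # checkpoint_on is a callback metric; after this commit it should no
        # longer appear in the dict returned by trainer.test()
        result = EvalResult(checkpoint_on=loss)
        result.log('test_loss', loss)
        return result

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.02)


def _loader():
    x = torch.randn(64, 32)
    y = torch.randint(0, 2, (64,))
    return DataLoader(TensorDataset(x, y), batch_size=16)


if __name__ == '__main__':
    model = TinyModel()
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, limit_test_batches=2)
    trainer.fit(model, _loader())
    results = trainer.test(model, _loader())
    # expected: only the logged 'test_loss' key, no 'checkpoint_on'
    print(results[0].keys())

If that expectation holds, the printed keys contain only 'test_loss', mirroring the assertions in the new `test_full_loop_result_cpu` test at the end of this commit.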
38 changes: 38 additions & 0 deletions tests/base/model_test_steps.py
@@ -49,6 +49,44 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
            })
        return output

    def test_step_result_obj(self, batch, batch_idx, *args, **kwargs):
        """
        Default, baseline test_step
        :param batch:
        :return:
        """
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        loss_test = self.loss(y, y_hat)

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        test_acc = torch.tensor(test_acc)

        test_acc = test_acc.type_as(x)

        result = EvalResult()
        # alternate possible outputs to test
        if batch_idx % 1 == 0:
            result.log_dict({
                'test_loss': loss_test,
                'test_acc': test_acc,
            })
            return result
        if batch_idx % 2 == 0:
            return test_acc

        if batch_idx % 3 == 0:
            result.log_dict({
                'test_loss': loss_test,
                'test_acc': test_acc,
            })
            result.test_dic = {'test_loss_a': loss_test}
            return result

    def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
        """
        Default, baseline test_step
19 changes: 19 additions & 0 deletions tests/base/model_train_steps.py
@@ -38,6 +38,25 @@ def training_step(self, batch, batch_idx, optimizer_idx=None):
        )
        return output

    def training_step_result_obj(self, batch, batch_idx, optimizer_idx=None):
        # forward pass
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        # calculate loss
        loss_val = self.loss(y, y_hat)
        log_val = loss_val

        # alternate between tensors and scalars for "log" and "progress_bar"
        if batch_idx % 2 == 0:
            log_val = log_val.item()

        result = TrainResult(loss_val)
        result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
        result.log('train_some_val', log_val * log_val)
        return result

    def training_step__inf_loss(self, batch, batch_idx, optimizer_idx=None):
        output = self.training_step(batch, batch_idx, optimizer_idx)
        if batch_idx == self.test_step_inf_loss:
20 changes: 20 additions & 0 deletions tests/base/model_valid_steps.py
@@ -1,5 +1,6 @@
from abc import ABC
from collections import OrderedDict
from pytorch_lightning.core.step_result import EvalResult

import torch

@@ -32,6 +33,25 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
            })
        return output

    def validation_step_result_obj(self, batch, batch_idx, *args, **kwargs):
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        loss_val = self.loss(y, y_hat)

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        val_acc = torch.tensor(val_acc).type_as(x)

        result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
        result.log_dict({
            'val_loss': loss_val,
            'val_acc': val_acc,
        })
        return result

    def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs):
        """
        Lightning calls this inside the validation loop
42 changes: 42 additions & 0 deletions tests/trainer/test_validation_steps_result_return.py
@@ -9,6 +9,7 @@
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
from tests.base.deterministic_model import DeterministicModel
from pytorch_lightning import seed_everything


# test with train_step_end
@@ -435,3 +436,44 @@ def test_val_step_full_loop_result_dp(tmpdir):
    assert 'epoch_test_step_metric' in seen_keys
    assert 'test_step_end_metric' in seen_keys
    assert 'test_epoch_end_metric' in seen_keys


def test_full_loop_result_cpu(tmpdir):
    seed_everything(1234)
    os.environ['PL_DEV_DEBUG'] = '1'

    batches = 5
    epochs = 2

    model = EvalModelTemplate()
    model.training_step = model.training_step_result_obj
    model.training_step_end = None
    model.training_epoch_end = None
    model.validation_step = model.validation_step_result_obj
    model.validation_step_end = None
    model.validation_epoch_end = None
    model.test_step = model.test_step_result_obj
    model.test_step_end = None
    model.test_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=batches,
        weights_summary=None,
    )

    trainer.fit(model)

    results = trainer.test()

    # assert we returned all metrics requested
    assert len(results) == 1
    results = results[0]
    assert results['test_loss'] < 0.3
    assert results['test_acc'] > 0.9
    assert len(results) == 2
    assert 'val_early_stop_on' not in results
    assert 'val_checkpoint_on' not in results
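As a hedged aside (not part of the commit), the new test can be run in isolation; assuming the repository's usual pytest setup, a runner like the snippet below selects only `test_full_loop_result_cpu`.

# hypothetical convenience runner, equivalent to invoking pytest from the shell
import pytest

pytest.main([
    'tests/trainer/test_validation_steps_result_return.py::test_full_loop_result_cpu',
    '-v',
])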
