tests for val loop flow #2605

Merged · 25 commits · Jul 14, 2020
2 changes: 1 addition & 1 deletion pl_examples/domain_templates/imagenet.py
@@ -245,7 +245,7 @@ def main(args: Namespace) -> None:
)

if args.evaluate:
trainer.run_evaluation()
trainer.test()
else:
trainer.fit(model)

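For context, a minimal sketch of the evaluate-only flow this hunk switches to; the model class and argument names are placeholders borrowed from the example script, not verified here:

from pytorch_lightning import Trainer

model = ImageNetLightningModel(args)   # placeholder: the example's LightningModule
trainer = Trainer(gpus=args.gpus)      # placeholder Trainer configuration

if args.evaluate:
    # the internal run_evaluation() call is replaced by the public test() API,
    # which runs the test loop and returns the collected metrics
    results = trainer.test(model)
else:
    trainer.fit(model)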
57 changes: 34 additions & 23 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -176,6 +176,7 @@ class TrainerEvaluationLoopMixin(ABC):
use_tpu: bool
reload_dataloaders_every_epoch: ...
tpu_id: int
verbose_test: bool

# Callback system
on_validation_batch_start: Callable
@@ -307,15 +308,16 @@ def _evaluate(
self.on_validation_batch_end()

# track outputs for collation
dl_outputs.append(output)
if output is not None:
dl_outputs.append(output)

outputs.append(dl_outputs)

eval_results = {}
eval_results = outputs

# with a single dataloader don't pass an array
if len(dataloaders) == 1:
outputs = outputs[0]
eval_results = outputs[0]

# give model a chance to do something with the outputs (and method defined)
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
@@ -324,22 +326,22 @@ if test_mode:
if test_mode:
if self.is_overridden('test_end', model=model):
# TODO: remove in v1.0.0
eval_results = model.test_end(outputs)
eval_results = model.test_end(eval_results)
rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
' Use `test_epoch_end` instead.', DeprecationWarning)

elif self.is_overridden('test_epoch_end', model=model):
eval_results = model.test_epoch_end(outputs)
eval_results = model.test_epoch_end(eval_results)

else:
if self.is_overridden('validation_end', model=model):
# TODO: remove in v1.0.0
eval_results = model.validation_end(outputs)
eval_results = model.validation_end(eval_results)
rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
' Use `validation_epoch_end` instead.', DeprecationWarning)

elif self.is_overridden('validation_epoch_end', model=model):
eval_results = model.validation_epoch_end(outputs)
eval_results = model.validation_epoch_end(eval_results)

# enable train mode again
model.train()
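The collation change in the hunk above can be summarized with a small standalone sketch (simplified, not the actual trainer code): step outputs that are None are skipped, and with a single dataloader the extra nesting is dropped before the epoch-end hooks receive the results.

def collate_eval_outputs(per_dataloader_outputs):
    # Simplified stand-in for the collation in _evaluate; assumes a list of
    # per-dataloader lists whose entries may be None when a step returned nothing.
    outputs = []
    for dl_outputs in per_dataloader_outputs:
        outputs.append([out for out in dl_outputs if out is not None])

    eval_results = outputs
    # with a single dataloader, don't pass a nested list to the *_epoch_end hooks
    if len(outputs) == 1:
        eval_results = outputs[0]
    return eval_results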
@@ -385,31 +387,40 @@ def run_evaluation(self, test_mode: bool = False):
# enable disabling validation step with limit_val_batches = 0
should_skip = sum(max_batches) == 0
if should_skip:
return
return [], []

# run evaluation
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)

# enable no returns
callback_metrics = {}
eval_loop_results = []
if eval_results is not None and len(eval_results) > 0:
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)

# add metrics to prog bar
self.add_progress_bar_metrics(prog_bar_metrics)
# in eval, the user may return something at every validation step without final reduction
if not isinstance(eval_results, list):
eval_results = [eval_results]

for result in eval_results:
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result)

# add metrics to prog bar
self.add_progress_bar_metrics(prog_bar_metrics)

# log results of test
if test_mode and self.is_global_zero and self.verbose_test:
print('-' * 80)
print('TEST RESULTS')
pprint(callback_metrics)
print('-' * 80)

# log results of test
if test_mode and self.is_global_zero:
print('-' * 80)
print('TEST RESULTS')
pprint(callback_metrics)
print('-' * 80)
# log metrics
self.log_metrics(log_metrics, {})

# log metrics
self.log_metrics(log_metrics, {})
# track metrics for callbacks
self.callback_metrics.update(callback_metrics)

# track metrics for callbacks
self.callback_metrics.update(callback_metrics)
if len(callback_metrics) > 0:
eval_loop_results.append(callback_metrics)

# hook
model.on_post_performance_check()
@@ -429,7 +440,7 @@ else:
else:
self.on_validation_end()

return callback_metrics
return eval_loop_results, eval_results

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
# make dataloader_idx arg in validation_step optional
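A hypothetical caller-side sketch of the new run_evaluation() contract shown above; `trainer` is assumed to be an already-configured Trainer with a test dataloader attached:

# run_evaluation() now returns two values instead of a single metrics dict:
#   eval_loop_results: one callback-metrics dict per processed result (may be empty)
#   eval_results: the raw outputs produced by the eval hooks
eval_loop_results, eval_results = trainer.run_evaluation(test_mode=True)

for metrics in eval_loop_results:
    # each entry is a plain dict of callback metrics, e.g. {'test_acc': ...}
    print(metrics)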
66 changes: 43 additions & 23 deletions pytorch_lightning/trainer/trainer.py
@@ -128,8 +128,9 @@ class Trainer(
>>> trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
>>> trainer.fit(model, train_loader)
1
>>> trainer.test(model, train_loader) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
1
>>> test_outputs = trainer.test(model, train_loader, verbose=False)
>>> len(test_outputs)
25
"""
DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores')

@@ -396,6 +397,9 @@ def __init__(
self.test_dataloaders = None
self.val_dataloaders = None

# when true, prints test results
self.verbose_test = True

# when .test() is called, it sets this
self.tested_ckpt_path = None

@@ -1125,7 +1129,6 @@ def run_pretrain_routine(self, model: LightningModule):
if self.logger is not None:
# save exp to get started
self.logger.log_hyperparams(ref_model.hparams)

self.logger.save()

if self.use_ddp or self.use_ddp2:
@@ -1163,22 +1166,38 @@ def run_pretrain_routine(self, model: LightningModule):
if self.testing:
# only load test dataloader for testing
# self.reset_test_dataloader(ref_model)
results = self.run_evaluation(test_mode=True)

# remove all cuda tensors
if results is not None and isinstance(results, dict) and len(results) > 0:
for k, v in results.items():
if isinstance(v, torch.Tensor):
results[k] = v.cpu().item()
eval_loop_results, _ = self.run_evaluation(test_mode=True)

return results
else:
if len(eval_loop_results) == 0:
return 1

# remove the tensors from the eval results
for i, result in enumerate(eval_loop_results):
if isinstance(result, dict):
for k, v in result.items():
if isinstance(v, torch.Tensor):
result[k] = v.cpu().item()

return eval_loop_results

# check if we should run validation during training
self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
and not self.fast_dev_run

# run a few val batches before training starts
self._run_sanity_check(ref_model, model)

# clear cache before training
if self.on_gpu and self.root_gpu is not None:
# use context because of:
# https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
with torch.cuda.device(f'cuda:{self.root_gpu}'):
torch.cuda.empty_cache()

# CORE TRAINING LOOP
self.train()

def _run_sanity_check(self, ref_model, model):
# run tiny validation (if validation defined)
# to make sure program won't crash during val
if not self.disable_validation and self.num_sanity_val_steps > 0:
@@ -1197,26 +1216,20 @@ def run_pretrain_routine(self, model: LightningModule):

# allow no returns from eval
if eval_results is not None and len(eval_results) > 0:
# when we get a list back, use only the last item
if isinstance(eval_results, list):
eval_results = eval_results[-1]
_, _, _, callback_metrics, _ = self.process_output(eval_results)
self.callback_metrics = callback_metrics

self.on_sanity_check_end()

# clear cache before training
if self.on_gpu and self.root_gpu is not None:
# use context because of:
# https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
with torch.cuda.device(f'cuda:{self.root_gpu}'):
torch.cuda.empty_cache()

# CORE TRAINING LOOP
self.train()

def test(
self,
model: Optional[LightningModule] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = 'best'
ckpt_path: Optional[str] = 'best',
verbose: bool = True
):
r"""

Expand All @@ -1231,6 +1244,11 @@ def test(
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None``, use the weights from the last epoch to test. Default to ``best``.

verbose: If True, prints the test results

Returns:
The final test result dictionary. If no test_epoch_end is defined, returns a list of dictionaries

Example::

# Option 1
@@ -1270,6 +1288,8 @@ def test(
# --------------------
# SETUP HOOK
# --------------------
self.verbose_test = verbose

if self.global_rank != 0:
return

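A hedged usage sketch of the new verbose flag and the list-of-dicts return value; `model`, `train_loader`, and `test_loader` are placeholders mirroring the doctest above:

trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
trainer.fit(model, train_loader)

# verbose=False suppresses the rank-zero 'TEST RESULTS' banner
results = trainer.test(model, test_loader, verbose=False)

# the return value is a list with one metrics dict per result; tensors were
# already moved to CPU and converted to Python scalars in run_pretrain_routine
assert isinstance(results, list)
print(results[0].get('test_acc'))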
54 changes: 53 additions & 1 deletion tests/base/deterministic_model.py
@@ -15,6 +15,10 @@ def __init__(self, weights=None):
self.training_step_end_called = False
self.training_epoch_end_called = False

self.validation_step_called = False
self.validation_step_end_called = False
self.validation_epoch_end_called = False

self.l1 = nn.Linear(2, 3, bias=False)
if weights is None:
weights = torch.tensor([
@@ -162,13 +166,61 @@ def training_epoch_end_dict(self, outputs):

return {'log': logs, 'progress_bar': pbar}

def validation_step_no_return(self, batch, batch_idx):
self.validation_step_called = True
acc = self.step(batch, batch_idx)

def validation_step_scalar_return(self, batch, batch_idx):
self.validation_step_called = True
acc = self.step(batch, batch_idx)
return acc

def validation_step_arbitary_dict_return(self, batch, batch_idx):
self.validation_step_called = True
acc = self.step(batch, batch_idx)
return {'some': acc, 'value': 'a'}

def validation_step_dict_return(self, batch, batch_idx):
self.validation_step_called = True
acc = self.step(batch, batch_idx)

logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)}
logs = {'log_acc1': torch.tensor(12 + batch_idx).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)}
pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)}
return {'val_loss': acc, 'log': logs, 'progress_bar': pbar}

def validation_step_end_no_return(self, val_step_output):
assert len(val_step_output) == 3
assert val_step_output['val_loss'] == 171
assert val_step_output['log']['log_acc1'] >= 12
assert val_step_output['progress_bar']['pbar_acc1'] == 17
self.validation_step_end_called = True

def validation_step_end(self, val_step_output):
assert len(val_step_output) == 3
assert val_step_output['val_loss'] == 171
assert val_step_output['log']['log_acc1'] >= 12
assert val_step_output['progress_bar']['pbar_acc1'] == 17
self.validation_step_end_called = True

val_step_output['val_step_end'] = torch.tensor(1802)

return val_step_output

def validation_epoch_end(self, outputs):
assert len(outputs) == self.trainer.num_val_batches[0]

for i, out in enumerate(outputs):
assert out['log']['log_acc1'] >= 12 + i

self.validation_epoch_end_called = True

result = outputs[-1]
result['val_epoch_end'] = torch.tensor(1233)
return result

# -----------------------------
# DATA
# -----------------------------
def train_dataloader(self):
return DataLoader(DummyDataset(), batch_size=3, shuffle=False)

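The validation_step variants above are presumably exercised by swapping them onto the model inside the new val-loop tests; the wiring below is an assumption for illustration, not code from this diff, and the import paths and variant names are assumed:

from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel

# assumed wiring: pick one validation_step variant, reuse the dummy train loader
# for validation, run a short fit, then check the flags the hooks set
model = DeterministicModel()
model.training_step = model.training_step_dict_return      # variant name assumed
model.validation_step = model.validation_step_dict_return
model.val_dataloader = model.train_dataloader

trainer = Trainer(max_epochs=1, limit_val_batches=3, weights_summary=None)
trainer.fit(model)

assert model.validation_step_called
assert model.validation_epoch_end_called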
4 changes: 2 additions & 2 deletions tests/models/test_restore.py
@@ -52,7 +52,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
pretrained_model.cpu()

# test we have good test accuracy
acc = results['test_acc']
acc = results[0]['test_acc']
assert acc > 0.5, f"Model failed to get expected {0.5} accuracy. test_acc = {acc}"

dataloaders = model.test_dataloader()
@@ -102,7 +102,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
results = new_trainer.test(pretrained_model)
pretrained_model.cpu()

acc = results['test_acc']
acc = results[0]['test_acc']
assert acc > 0.5, f"Model failed to get expected {0.5} accuracy. test_acc = {acc}"

dataloaders = model.test_dataloader()
12 changes: 6 additions & 6 deletions tests/models/test_test_loop.py
@@ -21,12 +21,12 @@ def test_single_gpu_test(tmpdir):
trainer.fit(model)
assert 'ckpt' in trainer.checkpoint_callback.best_model_path
results = trainer.test()
assert 'test_acc' in results
assert 'test_acc' in results[0]

old_weights = model.c_d1.weight.clone().detach().cpu()

results = trainer.test(model)
assert 'test_acc' in results
assert 'test_acc' in results[0]

# make sure weights didn't change
new_weights = model.c_d1.weight.clone().detach().cpu()
@@ -50,12 +50,12 @@ def test_dp_test(tmpdir):
trainer.fit(model)
assert 'ckpt' in trainer.checkpoint_callback.best_model_path
results = trainer.test()
assert 'test_acc' in results
assert 'test_acc' in results[0]

old_weights = model.c_d1.weight.clone().detach().cpu()

results = trainer.test(model)
assert 'test_acc' in results
assert 'test_acc' in results[0]

# make sure weights didn't change
new_weights = model.c_d1.weight.clone().detach().cpu()
@@ -79,12 +79,12 @@ def test_ddp_spawn_test(tmpdir):
trainer.fit(model)
assert 'ckpt' in trainer.checkpoint_callback.best_model_path
results = trainer.test()
assert 'test_acc' in results
assert 'test_acc' in results[0]

old_weights = model.c_d1.weight.clone().detach().cpu()

results = trainer.test(model)
assert 'test_acc' in results
assert 'test_acc' in results[0]

# make sure weights didn't change
new_weights = model.c_d1.weight.clone().detach().cpu()