tests for val loop flow (#2605)
* add tests for single scalar return from training

* fixing val step only
williamFalcon committed Jul 14, 2020
1 parent 548dbd1 commit aaa1553
Showing 8 changed files with 444 additions and 57 deletions.
2 changes: 1 addition & 1 deletion pl_examples/domain_templates/imagenet.py
@@ -245,7 +245,7 @@ def main(args: Namespace) -> None:
     )

     if args.evaluate:
-        trainer.run_evaluation()
+        trainer.test()
     else:
         trainer.fit(model)

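The example script now calls the public `Trainer.test()` entry point instead of the internal `run_evaluation()`. A minimal, self-contained sketch of that evaluate-only path (`TinyModel` and its data are invented for illustration; hook names follow the 0.8-era API shown in this diff):

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer, LightningModule


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

    def test_step(self, batch, batch_idx):
        x, y = batch
        return {'test_loss': F.cross_entropy(self(x), y)}

    def test_epoch_end(self, outputs):
        avg = torch.stack([o['test_loss'] for o in outputs]).mean()
        return {'progress_bar': {'test_loss': avg}}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def test_dataloader(self):
        data = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
        return DataLoader(data, batch_size=8)


# evaluate-only: no fit() call, mirroring the `args.evaluate` branch above
trainer = Trainer(max_epochs=1)
results = trainer.test(TinyModel())  # after this commit: a list of metric dicts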
57 changes: 34 additions & 23 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -176,6 +176,7 @@ class TrainerEvaluationLoopMixin(ABC):
     use_tpu: bool
     reload_dataloaders_every_epoch: ...
     tpu_id: int
+    verbose_test: bool

     # Callback system
     on_validation_batch_start: Callable
@@ -307,15 +308,16 @@ def _evaluate(
                 self.on_validation_batch_end()

                 # track outputs for collation
-                dl_outputs.append(output)
+                if output is not None:
+                    dl_outputs.append(output)

             outputs.append(dl_outputs)

-        eval_results = {}
+        eval_results = outputs

         # with a single dataloader don't pass an array
         if len(dataloaders) == 1:
-            outputs = outputs[0]
+            eval_results = outputs[0]

         # give model a chance to do something with the outputs (and method defined)
         if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
@@ -324,22 +326,22 @@ def _evaluate(
         if test_mode:
             if self.is_overridden('test_end', model=model):
                 # TODO: remove in v1.0.0
-                eval_results = model.test_end(outputs)
+                eval_results = model.test_end(eval_results)
                 rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
                                ' Use `test_epoch_end` instead.', DeprecationWarning)

             elif self.is_overridden('test_epoch_end', model=model):
-                eval_results = model.test_epoch_end(outputs)
+                eval_results = model.test_epoch_end(eval_results)

         else:
             if self.is_overridden('validation_end', model=model):
                 # TODO: remove in v1.0.0
-                eval_results = model.validation_end(outputs)
+                eval_results = model.validation_end(eval_results)
                 rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
                                ' Use `validation_epoch_end` instead.', DeprecationWarning)

             elif self.is_overridden('validation_epoch_end', model=model):
-                eval_results = model.validation_epoch_end(outputs)
+                eval_results = model.validation_epoch_end(eval_results)

         # enable train mode again
         model.train()
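Distilled from the hunk above: step outputs of `None` are now dropped during collation, and a single dataloader's outputs are unwrapped before they reach the `*_epoch_end` hooks. A standalone sketch of those two rules (`collate_eval_outputs` is a hypothetical helper, not Lightning API):

def collate_eval_outputs(per_dataloader_outputs):
    outputs = []
    for dl_outputs in per_dataloader_outputs:
        # new behavior: steps that returned nothing are skipped
        outputs.append([out for out in dl_outputs if out is not None])

    eval_results = outputs
    # with a single dataloader don't pass an array
    if len(outputs) == 1:
        eval_results = outputs[0]
    return eval_results


assert collate_eval_outputs([[{'loss': 1}, None]]) == [{'loss': 1}]
assert collate_eval_outputs([[1], [2]]) == [[1], [2]]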
@@ -385,31 +387,40 @@ def run_evaluation(self, test_mode: bool = False):
         # enable disabling validation step with limit_val_batches = 0
         should_skip = sum(max_batches) == 0
         if should_skip:
-            return
+            return [], []

         # run evaluation
         eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)

         # enable no returns
-        callback_metrics = {}
+        eval_loop_results = []
         if eval_results is not None and len(eval_results) > 0:
-            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)
-
-            # add metrics to prog bar
-            self.add_progress_bar_metrics(prog_bar_metrics)
+            # in eval, the user may return something at every validation step without final reduction
+            if not isinstance(eval_results, list):
+                eval_results = [eval_results]
+
+            for result in eval_results:
+                _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result)
+
+                # add metrics to prog bar
+                self.add_progress_bar_metrics(prog_bar_metrics)
+
+                # log results of test
+                if test_mode and self.is_global_zero and self.verbose_test:
+                    print('-' * 80)
+                    print('TEST RESULTS')
+                    pprint(callback_metrics)
+                    print('-' * 80)
+
+                # log metrics
+                self.log_metrics(log_metrics, {})
+
+                # track metrics for callbacks
+                self.callback_metrics.update(callback_metrics)

-            # log results of test
-            if test_mode and self.is_global_zero:
-                print('-' * 80)
-                print('TEST RESULTS')
-                pprint(callback_metrics)
-                print('-' * 80)
-
-            # log metrics
-            self.log_metrics(log_metrics, {})
-
-            # track metrics for callbacks
-            self.callback_metrics.update(callback_metrics)
+                if len(callback_metrics) > 0:
+                    eval_loop_results.append(callback_metrics)

         # hook
         model.on_post_performance_check()
@@ -429,7 +440,7 @@ def run_evaluation(self, test_mode: bool = False):
         else:
             self.on_validation_end()

-        return callback_metrics
+        return eval_loop_results, eval_results

     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
         # make dataloader_idx arg in validation_step optional
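`run_evaluation` now returns a `(eval_loop_results, eval_results)` pair instead of a single metrics dict, and first normalizes a lone epoch-end dict into a list so one loop can process both reduced and unreduced outputs. The normalization step in isolation (`normalize_eval_results` is a hypothetical name):

def normalize_eval_results(eval_results):
    # a model with *_epoch_end returns one dict; without it, a list of
    # per-step results -- wrap the former so the same loop handles both
    if not isinstance(eval_results, list):
        eval_results = [eval_results]
    return eval_results


assert normalize_eval_results({'val_loss': 0.1}) == [{'val_loss': 0.1}]
assert normalize_eval_results([{'a': 1}, {'a': 2}]) == [{'a': 1}, {'a': 2}]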
66 changes: 43 additions & 23 deletions pytorch_lightning/trainer/trainer.py
@@ -128,8 +128,9 @@ class Trainer(
         >>> trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
         >>> trainer.fit(model, train_loader)
         1
-        >>> trainer.test(model, train_loader) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-        1
+        >>> test_outputs = trainer.test(model, train_loader, verbose=False)
+        >>> len(test_outputs)
+        25
         """
     DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores')

@@ -396,6 +397,9 @@ def __init__(
         self.test_dataloaders = None
         self.val_dataloaders = None

+        # when true, prints test results
+        self.verbose_test = True

         # when .test() is called, it sets this
         self.tested_ckpt_path = None
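The new `verbose_test` flag defaults to on and is flipped by `Trainer.test(verbose=...)` further down. A hedged usage sketch (`LitModel` is a placeholder for any LightningModule with a test step and dataloader):

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=1)
trainer.fit(LitModel())

# verbose=False sets trainer.verbose_test = False, suppressing the
# 'TEST RESULTS' banner; the metrics are still returned
results = trainer.test(verbose=False)
assert isinstance(results, list)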

@@ -1125,7 +1129,6 @@ def run_pretrain_routine(self, model: LightningModule):
         if self.logger is not None:
             # save exp to get started
             self.logger.log_hyperparams(ref_model.hparams)
-
             self.logger.save()

         if self.use_ddp or self.use_ddp2:
@@ -1163,22 +1166,38 @@ def run_pretrain_routine(self, model: LightningModule):
         if self.testing:
             # only load test dataloader for testing
             # self.reset_test_dataloader(ref_model)
-            results = self.run_evaluation(test_mode=True)
-
-            # remove all cuda tensors
-            if results is not None and isinstance(results, dict) and len(results) > 0:
-                for k, v in results.items():
-                    if isinstance(v, torch.Tensor):
-                        results[k] = v.cpu().item()
-
-            return results
-        else:
+            eval_loop_results, _ = self.run_evaluation(test_mode=True)
+
+            if len(eval_loop_results) == 0:
+                return 1
+
+            # remove the tensors from the eval results
+            for i, result in enumerate(eval_loop_results):
+                if isinstance(result, dict):
+                    for k, v in result.items():
+                        if isinstance(v, torch.Tensor):
+                            result[k] = v.cpu().item()
+
+            return eval_loop_results

         # check if we should run validation during training
         self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
             and not self.fast_dev_run

+        # run a few val batches before training starts
+        self._run_sanity_check(ref_model, model)
+
+        # clear cache before training
+        if self.on_gpu and self.root_gpu is not None:
+            # use context because of:
+            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
+            with torch.cuda.device(f'cuda:{self.root_gpu}'):
+                torch.cuda.empty_cache()
+
+        # CORE TRAINING LOOP
+        self.train()
+
+    def _run_sanity_check(self, ref_model, model):
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         if not self.disable_validation and self.num_sanity_val_steps > 0:
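The testing branch now detaches any tensors in the collected metrics before returning them to the caller. The stripping step in isolation (`strip_tensors` is a hypothetical helper mirroring the loop above):

import torch


def strip_tensors(eval_loop_results):
    for result in eval_loop_results:
        if isinstance(result, dict):
            for k, v in result.items():
                if isinstance(v, torch.Tensor):
                    result[k] = v.cpu().item()  # plain Python number
    return eval_loop_results


out = strip_tensors([{'test_acc': torch.tensor(0.75)}])
assert out == [{'test_acc': 0.75}]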
@@ -1197,26 +1216,20 @@ def run_pretrain_routine(self, model: LightningModule):

             # allow no returns from eval
             if eval_results is not None and len(eval_results) > 0:
+                # when we get a list back, used only the last item
+                if isinstance(eval_results, list):
+                    eval_results = eval_results[-1]
                 _, _, _, callback_metrics, _ = self.process_output(eval_results)
                 self.callback_metrics = callback_metrics

             self.on_sanity_check_end()

-        # clear cache before training
-        if self.on_gpu and self.root_gpu is not None:
-            # use context because of:
-            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
-            with torch.cuda.device(f'cuda:{self.root_gpu}'):
-                torch.cuda.empty_cache()
-
-        # CORE TRAINING LOOP
-        self.train()
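The sanity-check change keeps only the last entry when the val loop hands back a list; in isolation (`pick_sanity_result` is a hypothetical name):

def pick_sanity_result(eval_results):
    # when we get a list back, use only the last item
    if isinstance(eval_results, list):
        eval_results = eval_results[-1]
    return eval_results


assert pick_sanity_result([{'a': 1}, {'a': 2}]) == {'a': 2}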

     def test(
         self,
         model: Optional[LightningModule] = None,
         test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
-        ckpt_path: Optional[str] = 'best'
+        ckpt_path: Optional[str] = 'best',
+        verbose: bool = True
     ):
         r"""
Expand All @@ -1231,6 +1244,11 @@ def test(
             ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
                 If ``None``, use the weights from the last epoch to test. Default to ``best``.

+            verbose: If True, prints the test results
+
+        Returns:
+            The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries
+
         Example::

             # Option 1
@@ -1270,6 +1288,8 @@ def test(
         # --------------------
         # SETUP HOOK
         # --------------------
+        self.verbose_test = verbose

         if self.global_rank != 0:
             return
54 changes: 53 additions & 1 deletion tests/base/deterministic_model.py
@@ -15,6 +15,10 @@ def __init__(self, weights=None):
         self.training_step_end_called = False
         self.training_epoch_end_called = False

+        self.validation_step_called = False
+        self.validation_step_end_called = False
+        self.validation_epoch_end_called = False
+
         self.l1 = nn.Linear(2, 3, bias=False)
         if weights is None:
             weights = torch.tensor([
@@ -162,13 +166,61 @@ def training_epoch_end_dict(self, outputs):

         return {'log': logs, 'progress_bar': pbar}

+    def validation_step_no_return(self, batch, batch_idx):
+        self.validation_step_called = True
+        acc = self.step(batch, batch_idx)
+
+    def validation_step_scalar_return(self, batch, batch_idx):
+        self.validation_step_called = True
+        acc = self.step(batch, batch_idx)
+        return acc
+
+    def validation_step_arbitary_dict_return(self, batch, batch_idx):
+        self.validation_step_called = True
+        acc = self.step(batch, batch_idx)
+        return {'some': acc, 'value': 'a'}
+
     def validation_step_dict_return(self, batch, batch_idx):
+        self.validation_step_called = True
         acc = self.step(batch, batch_idx)

-        logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)}
+        logs = {'log_acc1': torch.tensor(12 + batch_idx).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)}
         pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)}
         return {'val_loss': acc, 'log': logs, 'progress_bar': pbar}

+    def validation_step_end_no_return(self, val_step_output):
+        assert len(val_step_output) == 3
+        assert val_step_output['val_loss'] == 171
+        assert val_step_output['log']['log_acc1'] >= 12
+        assert val_step_output['progress_bar']['pbar_acc1'] == 17
+        self.validation_step_end_called = True
+
+    def validation_step_end(self, val_step_output):
+        assert len(val_step_output) == 3
+        assert val_step_output['val_loss'] == 171
+        assert val_step_output['log']['log_acc1'] >= 12
+        assert val_step_output['progress_bar']['pbar_acc1'] == 17
+        self.validation_step_end_called = True
+
+        val_step_output['val_step_end'] = torch.tensor(1802)
+
+        return val_step_output
+
+    def validation_epoch_end(self, outputs):
+        assert len(outputs) == self.trainer.num_val_batches[0]
+
+        for i, out in enumerate(outputs):
+            assert out['log']['log_acc1'] >= 12 + i
+
+        self.validation_epoch_end_called = True
+
+        result = outputs[-1]
+        result['val_epoch_end'] = torch.tensor(1233)
+        return result
+
     # -----------------------------
     # DATA
     # -----------------------------
     def train_dataloader(self):
         return DataLoader(DummyDataset(), batch_size=3, shuffle=False)
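These variants are intended to be swapped in per test. A hedged sketch of the wiring, assuming the rebinding pattern used by the matching training-step tests (the new test files themselves are not expanded in this page):

from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel

model = DeterministicModel()
model.validation_step = model.validation_step_scalar_return  # pick a variant
model.validation_step_end = None
model.validation_epoch_end = None

trainer = Trainer(fast_dev_run=True, weights_summary=None)
trainer.fit(model)
assert model.validation_step_called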
4 changes: 2 additions & 2 deletions tests/models/test_restore.py
@@ -52,7 +52,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
     pretrained_model.cpu()

     # test we have good test accuracy
-    acc = results['test_acc']
+    acc = results[0]['test_acc']
     assert acc > 0.5, f"Model failed to get expected {0.5} accuracy. test_acc = {acc}"

     dataloaders = model.test_dataloader()
@@ -102,7 +102,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
     results = new_trainer.test(pretrained_model)
     pretrained_model.cpu()

-    acc = results['test_acc']
+    acc = results[0]['test_acc']
     assert acc > 0.5, f"Model failed to get expected {0.5} accuracy. test_acc = {acc}"

     dataloaders = model.test_dataloader()
12 changes: 6 additions & 6 deletions tests/models/test_test_loop.py
@@ -21,12 +21,12 @@ def test_single_gpu_test(tmpdir):
     trainer.fit(model)
     assert 'ckpt' in trainer.checkpoint_callback.best_model_path
     results = trainer.test()
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     old_weights = model.c_d1.weight.clone().detach().cpu()

     results = trainer.test(model)
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     # make sure weights didn't change
     new_weights = model.c_d1.weight.clone().detach().cpu()
@@ -50,12 +50,12 @@ def test_dp_test(tmpdir):
     trainer.fit(model)
     assert 'ckpt' in trainer.checkpoint_callback.best_model_path
     results = trainer.test()
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     old_weights = model.c_d1.weight.clone().detach().cpu()

     results = trainer.test(model)
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     # make sure weights didn't change
     new_weights = model.c_d1.weight.clone().detach().cpu()
@@ -79,12 +79,12 @@ def test_ddp_spawn_test(tmpdir):
     trainer.fit(model)
     assert 'ckpt' in trainer.checkpoint_callback.best_model_path
     results = trainer.test()
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     old_weights = model.c_d1.weight.clone().detach().cpu()

     results = trainer.test(model)
-    assert 'test_acc' in results
+    assert 'test_acc' in results[0]

     # make sure weights didn't change
     new_weights = model.c_d1.weight.clone().detach().cpu()