Fix max_batches with fast_dev_run. (#2581)
* Fix fast_dev_run to run for all val_dataloaders

* fast_dev_run check

* changelog

* explicit

* limit_batches with fast_dev_run in init

* add test

* whitespace and comment fix

* comment and assertion

* added tests

* update rtol

* Revert "update rtol"

This reverts commit 4320329.

Co-authored-by: William Falcon <waf2107@columbia.edu>
rohitgr7 and williamFalcon committed Jul 27, 2020
1 parent 26afcaa commit 84c507c
Showing 10 changed files with 130 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))

- Fixed `fast_dev_run` to run for all dataloaders ([#2581](https://github.com/PyTorchLightning/pytorch-lightning/pull/2581))

- Fixed `save_dir` in loggers getting ignored by default value of `weights_save_path` when user did not specify `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))

- Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))
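
The gist of the fix: `fast_dev_run` previously capped evaluation at one batch of a single dataloader; it is now routed through the `limit_train/val/test_batches` machinery, so one batch runs for every dataloader. A minimal usage sketch (`my_project.MyModel` is a hypothetical stand-in for any LightningModule whose `val_dataloader()` returns several loaders):

from pytorch_lightning import Trainer
from my_project import MyModel  # hypothetical module, for illustration only

model = MyModel()
trainer = Trainer(fast_dev_run=True)  # internally sets limit_*_batches = 1
trainer.fit(model)

# one batch is drawn from each val dataloader, not just the first
assert trainer.num_val_batches == [1] * len(trainer.val_dataloaders)
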
19 changes: 5 additions & 14 deletions pytorch_lightning/callbacks/progress.py
@@ -88,8 +88,7 @@ def total_train_batches(self) -> int:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
training dataloader is of infinite size.
"""
total_train_batches = 1 if self.trainer.fast_dev_run else self.trainer.num_training_batches
return total_train_batches
return self.trainer.num_training_batches

@property
def total_val_batches(self) -> int:
@@ -98,13 +97,10 @@ def total_val_batches(self) -> int:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
validation dataloader is of infinite size.
"""
trainer = self.trainer
total_val_batches = 0
if trainer.fast_dev_run and trainer.val_dataloaders is not None:
total_val_batches = len(trainer.val_dataloaders)
elif self.trainer.enable_validation:
is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0
if not self.trainer.disable_validation:
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
return total_val_batches

@property
@@ -114,12 +110,7 @@ def total_test_batches(self) -> int:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
test dataloader is of infinite size.
"""
if self.trainer.fast_dev_run:
total_test_batches = len(self.trainer.test_dataloaders)
else:
total_test_batches = self.trainer.num_test_batches
total_test_batches = sum(total_test_batches)
return total_test_batches
return sum(self.trainer.num_test_batches)

def disable(self):
"""
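
The progress bar no longer special-cases `fast_dev_run`: the trainer's per-dataloader batch counts already reflect the cap, so the totals reduce to plain sums. A tiny sketch with illustrative values (not taken from any real run):

# per-dataloader counts as the trainer now stores them; with
# fast_dev_run=True each entry is already capped to 1 upstream
num_val_batches = [1, 1, 1]   # three val dataloaders
num_test_batches = [1, 1]     # two test dataloaders

total_val_batches = sum(num_val_batches)    # 3: one batch per loader
total_test_batches = sum(num_test_batches)  # 2
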
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/__init__.py
@@ -393,7 +393,7 @@ def on_train_end(self, trainer, pl_module):
# default used by the Trainer
trainer = Trainer(fast_dev_run=False)
# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)
gpus
10 changes: 4 additions & 6 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -306,7 +306,7 @@ def _evaluate(
if batch is None:
continue

# stop short when on fast_dev_run (sets max_batch=1)
# stop short when running on limited batches
if batch_idx >= dl_max_batches:
break

@@ -350,6 +350,9 @@ def _evaluate(

self.__eval_add_step_metrics(output)

# track debug metrics
self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)

outputs.append(dl_outputs)

# ---------------------
@@ -513,14 +516,9 @@ def run_evaluation(self, test_mode: bool = False):
dataloaders = self.val_dataloaders
max_batches = self.num_val_batches

# enable fast_dev_run without val loop
if dataloaders is None:
return [], []

# cap max batches to 1 when using fast_dev_run
if self.fast_dev_run:
max_batches = [1]

# Validation/Test begin callbacks
if test_mode:
self.on_test_start()
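
With the `max_batches = [1]` override gone, the per-dataloader caps arrive via `num_val_batches` / `num_test_batches`, and the inner loop breaks on each cap independently. A self-contained sketch of that break logic, with plain ranges standing in for dataloaders:

# two stand-in "dataloaders" with a per-loader cap of 1,
# as fast_dev_run now arranges via limit_val_batches
dataloaders = [range(5), range(3)]
max_batches = [1, 1]

seen = []
for dataloader_idx, (dataloader, dl_max_batches) in enumerate(zip(dataloaders, max_batches)):
    for batch_idx, batch in enumerate(dataloader):
        # stop short when running on limited batches
        if batch_idx >= dl_max_batches:
            break
        seen.append((dataloader_idx, batch_idx))

assert seen == [(0, 0), (1, 0)]  # exactly one batch from each loader
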
11 changes: 9 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -278,7 +278,7 @@ def __init__(
check_val_every_n_epoch: Check val every n train epochs.
fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
@@ -505,7 +505,11 @@ def __init__(
self.max_steps = max_steps
self.min_steps = min_steps

self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
if num_sanity_val_steps == -1:
self.num_sanity_val_steps = float("inf")
else:
self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)

# Backward compatibility, TODO: remove in v0.9.0
if print_nan_grads:
rank_zero_warn(
@@ -528,6 +532,9 @@

self.fast_dev_run = fast_dev_run
if self.fast_dev_run:
limit_train_batches = 1
limit_val_batches = 1
limit_test_batches = 1
self.num_sanity_val_steps = 0
self.max_epochs = 1
rank_zero_info(
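
Condensed, the new `__init__` ordering: sanity-check steps are first clamped by `limit_val_batches`, then `fast_dev_run` overrides all three limits before anything downstream reads them. A standalone sketch of just that ordering (illustrative values, not the full constructor):

# example arguments; 2 and 1.0 mirror the real Trainer defaults
num_sanity_val_steps = 2
limit_train_batches = limit_val_batches = limit_test_batches = 1.0
fast_dev_run = True

if num_sanity_val_steps == -1:
    num_sanity_val_steps = float("inf")
else:
    num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)

if fast_dev_run:
    # one batch everywhere, no sanity check, a single epoch
    limit_train_batches = limit_val_batches = limit_test_batches = 1
    num_sanity_val_steps = 0
    max_epochs = 1
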
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -401,7 +401,7 @@ def train(self):
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

if self.should_stop:
if (met_min_epochs and met_min_steps) or self.fast_dev_run:
if (met_min_epochs and met_min_steps):
self.run_training_teardown()
return
else:
@@ -507,7 +507,7 @@ def run_training_epoch(self):
# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
# -----------------------------------------
should_check_val = self.should_check_val(batch_idx, is_last_batch)
if self.fast_dev_run or should_check_val:
if should_check_val:
self.run_evaluation(test_mode=False)

# -----------------------------------------
@@ -530,7 +530,7 @@ def run_training_epoch(self):
# end epoch early
# stop when the flag is changed or we've gone past the amount
# requested in the batches
if self.fast_dev_run or self.should_stop:
if self.should_stop:
break

# let ddp devices catch up when using horovod
@@ -548,7 +548,7 @@
def check_checkpoint_callback(self, should_check_val):
# when no val loop is present or fast-dev-run still need to call checkpoints
# TODO bake this logic into the checkpoint callback
should_activate = not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val)
should_activate = not self.is_overridden('validation_step') and not should_check_val
if should_activate:
checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
[c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks]
39 changes: 39 additions & 0 deletions pytorch_lightning/utilities/debugging.py
@@ -1,4 +1,5 @@
import os
from collections import Counter


class InternalDebugger(object):
@@ -10,6 +11,8 @@ def __init__(self, trainer):
self.logged_metrics = []
self.pbar_added_metrics = []
self.saved_losses = []
self.saved_val_losses = []
self.saved_test_losses = []
self.early_stopping_history = []
self.checkpoint_callback_history = []

@@ -23,6 +26,21 @@ def track_train_loss_history(self, batch_idx, loss):
loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()}
self.saved_losses.append(loss_dict)

def track_eval_loss_history(self, test_mode, batch_idx, dataloader_idx, output):
if self.enabled:
loss_dict = {
'sanity_check': self.trainer.running_sanity_check,
'dataloader_idx': dataloader_idx,
'batch_idx': batch_idx,
'epoch': self.trainer.current_epoch,
'output': output
}

if test_mode:
self.saved_test_losses.append(loss_dict)
else:
self.saved_val_losses.append(loss_dict)

def track_pbar_metrics_history(self, metrics):
if self.enabled:
metrics['debug_epoch'] = self.trainer.current_epoch
@@ -52,3 +70,24 @@ def track_checkpointing_history(self, filepath):
'filepath': filepath
}
self.checkpoint_callback_history.append(debug_dict)

@property
def num_seen_sanity_check_batches(self):
count = len([x for x in self.saved_val_losses if x['sanity_check']])
return count

@property
def num_seen_val_check_batches(self):
counts = Counter()
for x in self.saved_val_losses:
if not x['sanity_check']:
counts.update({x['dataloader_idx']: 1})
return counts

@property
def num_seen_test_check_batches(self):
counts = Counter()
for x in self.saved_test_losses:
if not x['sanity_check']:
counts.update({x['dataloader_idx']: 1})
return counts
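
The new properties tally recorded eval batches per `dataloader_idx`, skipping sanity-check entries; the added tests assert against these counts. A quick standalone illustration of the `Counter` pattern with made-up records:

from collections import Counter

# made-up records, shaped like the debugger's saved_val_losses entries
saved_val_losses = [
    {'sanity_check': True,  'dataloader_idx': 0, 'batch_idx': 0},
    {'sanity_check': False, 'dataloader_idx': 0, 'batch_idx': 0},
    {'sanity_check': False, 'dataloader_idx': 1, 'batch_idx': 0},
]

counts = Counter()
for x in saved_val_losses:
    if not x['sanity_check']:
        counts.update({x['dataloader_idx']: 1})

assert counts == Counter({0: 1, 1: 1})  # one real val batch per dataloader
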
2 changes: 2 additions & 0 deletions tests/callbacks/test_progress_bar.py
@@ -116,6 +116,8 @@ def test_progress_bar_fast_dev_run(tmpdir):
fast_dev_run=True,
)

trainer.fit(model)

progress_bar = trainer.progress_bar_callback
assert 1 == progress_bar.total_train_batches
# total val batches are known only after val dataloaders have reloaded
2 changes: 1 addition & 1 deletion tests/models/test_grad_norm.py
@@ -43,7 +43,7 @@ def on_after_backward(self):


@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
def test_grad_tracking(tmpdir, norm_type, rtol=1e-2):
os.environ['PL_DEV_DEBUG'] = '1'

# rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above
64 changes: 63 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -1,3 +1,4 @@
import os
import platform
from unittest.mock import patch

@@ -306,6 +307,8 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
"""Verify num_batches for val & test dataloaders passed with batch limit as number"""
os.environ['PL_DEV_DEBUG'] = '1'

model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -323,16 +326,75 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v
limit_test_batches=limit_test_batches,
)
trainer.fit(model)

# -------------------------------------------
# MAKE SURE THE TRAINER SET THE CORRECT VALUES
# -------------------------------------------
assert trainer.num_training_batches == limit_train_batches
assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
trainer.test(ckpt_path=None)

# when the limit is greater than the number of test batches it should be the num in loaders
test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
if limit_test_batches > 1e10:
assert trainer.num_test_batches == [len(x) for x in model.test_dataloader()]
assert trainer.num_test_batches == test_dataloader_lengths
else:
assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders)

# -------------------------------------------
# make sure we actually saw the expected num of batches
# -------------------------------------------
num_val_dataloaders = len(model.val_dataloader())
num_test_dataloaders = len(model.test_dataloader())
if limit_train_batches > 0:

# make sure val batches are as expected
assert len(trainer.dev_debugger.num_seen_val_check_batches) == num_val_dataloaders
for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_val_check_batches.items():
assert num_batches == limit_val_batches

# make sure test batches are as expected
assert len(trainer.dev_debugger.num_seen_test_check_batches) == num_test_dataloaders
for dataloader_idx, num_batches in trainer.dev_debugger.num_seen_test_check_batches.items():
if limit_test_batches > 1e10:
assert num_batches == test_dataloader_lengths[dataloader_idx]
else:
assert num_batches == limit_test_batches


def test_dataloaders_with_fast_dev_run(tmpdir):
"""Verify num_batches for train, val & test dataloaders passed with fast_dev_run = True"""
os.environ['PL_DEV_DEBUG'] = '1'

model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
model.test_step = model.test_step__multiple_dataloaders
model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

# train, multiple val and multiple test dataloaders passed with fast_dev_run = True
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
fast_dev_run=True,
)
assert trainer.max_epochs == 1
assert trainer.num_sanity_val_steps == 0

trainer.fit(model)
assert not trainer.disable_validation
assert trainer.num_training_batches == 1
assert trainer.num_val_batches == [1] * len(trainer.val_dataloaders)

trainer.test(ckpt_path=None)
assert trainer.num_test_batches == [1] * len(trainer.test_dataloaders)

# verify sanity check batches match as expected
num_val_dataloaders = len(model.val_dataloader())
assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders


@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
