Fixed val interval (#405)
* added fixed frequency val batch check

* added fixed frequency val batch check

* Finished IterableDataset support

* flake8

* flake8

* flake8
williamFalcon authored Oct 22, 2019
1 parent ab67944 commit 792ad00
Showing 4 changed files with 41 additions and 7 deletions.
9 changes: 8 additions & 1 deletion docs/Trainer/Validation loop.md
@@ -43,13 +43,20 @@ trainer = Trainer(test_percent_check=0.1)

---
#### Set validation check frequency within 1 training epoch
For large datasets it's often desirable to check validation multiple times within a training loop
For large datasets it's often desirable to check validation multiple times within a training loop.
Pass in a float to check that often within 1 training epoch.
Pass in an int k to check every k training batches. Must use an int if using
an IterableDataset.

``` {.python}
# DEFAULT
trainer = Trainer(val_check_interval=0.95)
# check every .25 of an epoch
trainer = Trainer(val_check_interval=0.25)
# check every 100 train batches (ie: for IterableDatasets or fixed frequency)
trainer = Trainer(val_check_interval=100)
```
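To see why an int is required for an `IterableDataset`: such a dataset has no `__len__`, so a fractional interval has no epoch length to resolve against. A minimal sketch of that situation (the class and variable names are illustrative, not part of this commit):

``` {.python}
import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamingDataset(IterableDataset):
    """Yields samples indefinitely; the epoch length is unknown up front."""

    def __iter__(self):
        while True:
            yield torch.randn(32), torch.randint(0, 2, (1,))


loader = DataLoader(StreamingDataset(), batch_size=None)
# len(loader) raises TypeError here, so a float val_check_interval cannot be
# converted into a batch count -- pass an int such as val_check_interval=100.
```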

---
29 changes: 25 additions & 4 deletions pytorch_lightning/trainer/data_loading_mixin.py
@@ -2,6 +2,9 @@

from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.utils.data import IterableDataset

from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
from apex import amp
@@ -15,8 +18,11 @@ class TrainerDataLoadingMixin(object):
def layout_bookeeping(self):

# determine number of training batches
self.nb_training_batches = len(self.get_train_dataloader())
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)
if isinstance(self.get_train_dataloader(), IterableDataset):
self.nb_training_batches = float('inf')
else:
self.nb_training_batches = len(self.get_train_dataloader())
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)

# determine number of validation batches
# val datasets could be none, 1 or 2+
@@ -34,8 +40,13 @@ def layout_bookeeping(self):
self.nb_test_batches = max(1, self.nb_test_batches)

# determine when to check validation
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
# if int passed in, val checks that often
# otherwise, it checks in [0, 1.0] % range of a training epoch
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
else:
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
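To make the interval arithmetic above concrete, a small standalone sketch (the batch counts are made-up numbers, not taken from the commit):

``` {.python}
def resolve_val_check_batch(val_check_interval, nb_training_batches):
    # int -> fixed frequency in batches; float -> fraction of one training epoch
    if isinstance(val_check_interval, int):
        return val_check_interval
    return max(1, int(nb_training_batches * val_check_interval))


print(resolve_val_check_batch(0.25, 1000))         # 250: every quarter epoch
print(resolve_val_check_batch(100, 1000))          # 100: fixed frequency
print(resolve_val_check_batch(100, float('inf')))  # 100: works when epoch length is unknown
```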

def get_dataloaders(self, model):
"""
@@ -127,6 +138,16 @@ def get_dataloaders(self, model):
self.get_test_dataloaders()
self.get_val_dataloaders()

# support IterableDataset for train data
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset)

ehhuang commented on Oct 23, 2019: Shouldn't this be `self.get_train_dataloader().dataset`?

if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
m = '''
When using an IterableDataset for train_dataloader,
Trainer(val_check_interval) must be an int.
An int k specifies checking validation every k training batches
'''
raise MisconfigurationException(m)

def determine_data_use_amount(self, train_percent_check, val_percent_check,
test_percent_check, overfit_pct):
"""
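On ehhuang's comment above: `get_train_dataloader()` returns a `DataLoader`, which wraps the dataset rather than subclassing it, so an `IterableDataset` check only matches the loader's `.dataset` attribute. A minimal sketch of the distinction (names are illustrative):

``` {.python}
from torch.utils.data import DataLoader, IterableDataset


class Stream(IterableDataset):
    def __iter__(self):
        yield from range(10)


loader = DataLoader(Stream(), batch_size=2)
print(isinstance(loader, IterableDataset))          # False: the loader itself
print(isinstance(loader.dataset, IterableDataset))  # True: the wrapped dataset
```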
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/train_loop_mixin.py
@@ -31,7 +31,12 @@ def train(self):

# init progress_bar when requested
if self.show_progress_bar:
self.progress_bar.reset(self.total_batches)
nb_iterations = self.total_batches

# for iterable train loader, the progress bar never ends
if self.is_iterable_train_dataloader:
nb_iterations = float('inf')
self.progress_bar.reset(nb_iterations)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
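With an iterable train loader the planned epoch length is effectively infinite, so anything that was keyed off the epoch total has to be expressed per batch instead. A rough sketch of how a fixed-frequency check behaves in that setting (not the Trainer's actual loop, just the arithmetic):

``` {.python}
# Epoch length is unknown (conceptually infinite) for an iterable train loader,
# so validation is triggered off the running batch index alone.
val_check_batch = 100  # from an int val_check_interval

for batch_nb in range(1, 301):  # pretend the stream yields 300 batches
    if batch_nb % val_check_batch == 0:
        print(f'run validation after batch {batch_nb}')  # batches 100, 200, 300
```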
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -101,7 +101,7 @@ def __init__(self,
:param train_percent_check: int. How much of train set to check
:param val_percent_check: int. How much of val set to check
:param test_percent_check: int. How much of test set to check
:param val_check_interval: int. Check val this frequently within a train epoch
:param val_check_interval: float/int. If float, % of tng epoch. If int, check every n batch
:param log_save_interval: int. Writes logs to disk this often
:param row_log_interval: int. How often to add logging rows
:param add_row_log_interval: int. How often to add logging rows. Deprecated.
@@ -160,6 +160,7 @@ def __init__(self,
self.get_train_dataloader = None
self.get_test_dataloaders = None
self.get_val_dataloaders = None
self.is_iterable_train_dataloader = False

# training state
self.model = None
