Fixed val interval #405

Merged · 7 commits · Oct 22, 2019
9 changes: 8 additions & 1 deletion docs/Trainer/Validation loop.md
@@ -43,13 +43,20 @@ trainer = Trainer(test_percent_check=0.1)

---
#### Set validation check frequency within 1 training epoch
For large datasets it's often desirable to check validation multiple times within a training loop
For large datasets it's often desirable to check validation multiple times within a training loop.
Pass in a float to check that often within 1 training epoch.
Pass in an int k to check every k training batches. Must use an int if using
an IterableDataset.

``` {.python}
# DEFAULT
trainer = Trainer(val_check_interval=0.95)

# check every .25 of an epoch
trainer = Trainer(val_check_interval=0.25)

# check every 100 train batches (ie: for IterableDatasets or fixed frequency)
trainer = Trainer(val_check_interval=100)
```
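
For the IterableDataset case mentioned above, a minimal sketch (the dataset class below is illustrative, not part of this PR):

``` {.python}
import torch
from torch.utils.data import IterableDataset

class StreamDataset(IterableDataset):
    """Illustrative stream-style dataset: no __len__, samples produced on the fly."""
    def __iter__(self):
        while True:
            yield torch.randn(32)

# the epoch length is unknown, so a fractional interval cannot be resolved
# into a batch index; pass an int to run validation every k training batches
trainer = Trainer(val_check_interval=100)
```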

---
29 changes: 25 additions & 4 deletions pytorch_lightning/trainer/data_loading_mixin.py
@@ -2,6 +2,9 @@

from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.utils.data import IterableDataset

from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
from apex import amp
@@ -15,8 +18,11 @@ class TrainerDataLoadingMixin(object):
def layout_bookeeping(self):

# determine number of training batches
self.nb_training_batches = len(self.get_train_dataloader())
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)
if isinstance(self.get_train_dataloader(), IterableDataset):
Review comment (Contributor): Looks like this needs to be isinstance(self.get_train_dataloader().dataset, IterableDataset)

self.nb_training_batches = float('inf')
Review comment (Contributor): and tqdm crashes because of the float('inf')

else:
self.nb_training_batches = len(self.get_train_dataloader())
self.nb_training_batches = int(self.nb_training_batches * self.train_percent_check)

# determine number of validation batches
# val datasets could be none, 1 or 2+
@@ -34,8 +40,13 @@ def layout_bookeeping(self):
self.nb_test_batches = max(1, self.nb_test_batches)

# determine when to check validation
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
# if int passed in, val checks that often
# otherwise, it checks in [0, 1.0] % range of a training epoch
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
else:
self.val_check_batch = int(self.nb_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

def get_dataloaders(self, model):
"""
@@ -127,6 +138,16 @@ def get_dataloaders(self, model):
self.get_test_dataloaders()
self.get_val_dataloaders()

# support IterableDataset for train data
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset)
if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
m = '''
When using an iterableDataset for train_dataloader,
Trainer(val_check_interval) must be an int.
An int k specifies checking validation every k training batches
'''
raise MisconfigurationException('when using ')

def determine_data_use_amount(self, train_percent_check, val_percent_check,
test_percent_check, overfit_pct):
"""
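
To make the new scheduling rule concrete, here is a standalone sketch of the same decisions (the function and variable names are illustrative, not the library's API). It also applies the .dataset check suggested in the review comment above, since a DataLoader is never itself an instance of IterableDataset:

```python
from torch.utils.data import DataLoader, IterableDataset

def resolve_val_check_batch(train_dataloader, val_check_interval):
    """Return after how many training batches validation should run."""
    # inspect .dataset: the DataLoader wrapper is never an IterableDataset itself
    is_iterable = isinstance(train_dataloader.dataset, IterableDataset)

    if isinstance(val_check_interval, int):
        # int k: run validation every k training batches (required for iterable data)
        return val_check_interval

    if is_iterable:
        raise ValueError('val_check_interval must be an int when training on an IterableDataset')

    # float in [0, 1]: run validation that fraction of the way through each epoch
    nb_training_batches = len(train_dataloader)
    return max(1, int(nb_training_batches * val_check_interval))

# e.g. 400 batches per epoch with val_check_interval=0.25 -> validate every 100 batches
```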
7 changes: 6 additions & 1 deletion pytorch_lightning/trainer/train_loop_mixin.py
@@ -31,7 +31,12 @@ def train(self):

# init progress_bar when requested
if self.show_progress_bar:
self.progress_bar.reset(self.total_batches)
nb_iterations = self.total_batches

# for iterable train loader, the progress bar never ends
if self.is_iterable_train_dataloader:
nb_iterations = float('inf')
self.progress_bar.reset(nb_iterations)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
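
One of the review comments above points out that tqdm crashes when handed float('inf') as a total. A common workaround for loaders of unknown length (an assumption about how this could be handled, not the fix applied in this PR) is to leave the total unset, so tqdm shows a running count and rate instead of a percentage bar:

```python
from tqdm import tqdm

# total=None: tqdm displays iteration count and rate, but no percentage bar,
# which suits an IterableDataset whose length is not known in advance
progress_bar = tqdm(total=None, desc='training')
for batch in train_dataloader:  # train_dataloader is assumed to be defined
    progress_bar.update(1)
```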
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -101,7 +101,7 @@ def __init__(self,
:param train_percent_check: int. How much of train set to check
:param val_percent_check: int. How much of val set to check
:param test_percent_check: int. How much of test set to check
:param val_check_interval: int. Check val this frequently within a train epoch
:param val_check_interval: float/int. If float, % of tng epoch. If int, check every n batch
:param log_save_interval: int. Writes logs to disk this often
:param row_log_interval: int. How often to add logging rows
:param add_row_log_interval: int. How often to add logging rows. Deprecated.
@@ -160,6 +160,7 @@ def __init__(self,
self.get_train_dataloader = None
self.get_test_dataloaders = None
self.get_val_dataloaders = None
self.is_iterable_train_dataloader = False

# training state
self.model = None