Add support for IterableDatasets everywhere #1104

Merged

Changes from 1 commit
38 changes: 31 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
@@ -163,10 +163,22 @@ def reset_val_dataloader(self, model):
         # determine number of validation batches
         # val datasets could be none, 1 or 2+
         if self.val_dataloaders is not None:
-            self._percent_range_check('val_percent_check')
+            for dataloader in self.val_dataloaders:
+                if self.is_infinite_dataloader(dataloader):
+                    self.num_val_batches = float('inf')
 
-            self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders)
-            self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
+            if self.num_val_batches != float('inf'):
+                self._percent_range_check('val_percent_check')
+
+                self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders)
+                self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
+            elif not (self.val_percent_check == 1.0 or self.val_percent_check == 0.0):
+                m = '''
+                When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
+                does not implement `__len__`) for `val_dataloader`, `Trainer(val_percent_check)`
+                must be `0.0` or `1.0`.
+                '''
+                raise MisconfigurationException(m)
 
     def reset_test_dataloader(self, model):
         """Dataloaders are provided by the model.
@@ -188,11 +200,23 @@ def reset_test_dataloader(self, model):
 
         # determine number of test batches
        if self.test_dataloaders is not None:
-            self._percent_range_check('test_percent_check')
+            for dataloader in self.test_dataloaders:
+                if self.is_infinite_dataloader(dataloader):
+                    self.num_test_batches = float('inf')
 
-            len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders)
-            self.num_test_batches = len_sum
-            self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
+            if self.num_test_batches != float('inf'):
+                self._percent_range_check('test_percent_check')
+
+                len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders)
+                self.num_test_batches = len_sum
+                self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
+            elif not (self.test_percent_check == 1.0 or self.test_percent_check == 0.0):
+                m = '''
+                When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
+                does not implement `__len__`) for `test_dataloader`, `Trainer(test_percent_check)`
+                must be `0.0` or `1.0`.
+                '''
+                raise MisconfigurationException(m)
 
     def request_data_loader(self, data_loader_fx):
         """
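
For reference, the is_infinite_dataloader helper used in both branches treats any DataLoader whose length cannot be determined as infinite. A minimal sketch of such a check, assuming it simply probes len(); the helper name and exact exception handling below are illustrative, not the PR's actual implementation:

    from torch.utils.data import DataLoader


    def dataloader_has_len(dataloader: DataLoader) -> bool:
        """Return False for 'infinite' loaders, e.g. ones backed by an IterableDataset.

        A DataLoader that cannot report its length raises TypeError from len().
        """
        try:
            len(dataloader)
            return True
        except TypeError:
            return False

With such a check in place, the branches above reduce to: finite loaders keep the percent-based batch limiting, while infinite ones set num_val_batches / num_test_batches to float('inf') and only accept a val/test percent check of 0.0 or 1.0.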
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -359,9 +359,9 @@ def run_evaluation(self, test_mode: bool = False):
         # main progress bar will already be closed when testing so initial position is free
         position = 2 * self.process_position + (not test_mode)
         desc = 'Testing' if test_mode else 'Validating'
-        pbar = tqdm(desc=desc, total=max_batches, leave=test_mode, position=position,
-                    disable=not self.show_progress_bar, dynamic_ncols=True,
-                    file=sys.stdout)
+        total = max_batches if max_batches != float('inf') else None
+        pbar = tqdm(desc=desc, total=total, leave=test_mode, position=position,
+                    disable=not self.show_progress_bar, dynamic_ncols=True, file=sys.stdout)
         setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar)
 
         # run evaluation
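
The reason for converting float('inf') to None: tqdm only draws a determinate bar when total is a finite number, and with total=None it falls back to a plain counter with a rate, which is the sensible display for a loader of unknown length. A small standalone illustration, not part of the PR:

    from tqdm import tqdm

    # Known length: a normal progress bar.
    for _ in tqdm(range(100), total=100, desc='finite'):
        pass

    # Unknown length: total=None shows only the iteration count and rate.
    for _ in tqdm(iter(range(100)), total=None, desc='infinite-style'):
        pass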
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -310,7 +310,7 @@ def train(self):
 
             total_val_batches = 0
             is_val_epoch = False
-            if not self.disable_validation:
+            if not self.disable_validation and self.num_training_batches != float('inf'):
                 # val can be checked multiple times in epoch
                 is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
                 val_checks_per_epoch = self.num_training_batches // self.val_check_batch
@@ -324,8 +324,8 @@
             if self.fast_dev_run:
                 # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                 num_iterations = 2
-            elif self.is_infinite_dataloader(self.train_dataloader):
-                # for infinite train loader, the progress bar never ends
+            elif self.total_batches == float('inf'):
+                # for infinite train or val loader, the progress bar never ends
                 num_iterations = None
            else:
                 num_iterations = self.total_batches
@@ -334,7 +334,7 @@
             # .reset() doesn't work on disabled progress bar so we should check
             if not self.main_progress_bar.disable:
                 self.main_progress_bar.reset(num_iterations)
-            desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
+            desc = f'Epoch {epoch + 1}'
             self.main_progress_bar.set_description(desc)
 
             # -----------------
3 changes: 3 additions & 0 deletions tests/models/__init__.py
@@ -19,6 +19,9 @@
     LightValStepFitMultipleDataloadersMixin,
     LightTrainDataloader,
     LightTestDataloader,
+    LightInfTrainDataloader,
+    LightInfValDataloader,
+    LightInfTestDataloader,
     LightTestOptimizerWithSchedulingMixin,
     LightTestMultipleOptimizersWithSchedulingMixin,
     LightTestOptimizersWithMixedSchedulingMixin
42 changes: 42 additions & 0 deletions tests/models/mixins.py
@@ -213,6 +213,48 @@ def test_dataloader(self):
         return self._dataloader(train=False)
 
 
+class CustomInfDataloader:
+    def __init__(self, dataloader):
+        self.dataloader = dataloader
+        self.iter = iter(dataloader)
+        self.count = 0
+
+    def __iter__(self):
+        self.count = 0
+        return self
+
+    def __next__(self):
+        if self.count >= 50:
+            raise StopIteration
+        self.count = self.count + 1
+        try:
+            return next(self.iter)
+        except StopIteration:
+            self.iter = iter(self.dataloader)
+            return next(self.iter)
+
+
+class LightInfTrainDataloader:
+    """Simple test dataloader."""
+
+    def train_dataloader(self):
+        return CustomInfDataloader(self._dataloader(train=True))
+
+
+class LightInfValDataloader:
+    """Simple test dataloader."""
+
+    def val_dataloader(self):
+        return CustomInfDataloader(self._dataloader(train=False))
+
+
+class LightInfTestDataloader:
+    """Simple test dataloader."""
+
+    def test_dataloader(self):
+        return CustomInfDataloader(self._dataloader(train=False))
+
+
 class LightEmptyTestStep:
     """Empty test step."""
 

Review thread on `class CustomInfDataloader:` (marked as resolved):

Member:
I would rather create a complete Dataloader so it is easier to understand... What about?

    class CustomInfDataloader:

        def __init__(self, dataset, batch_size, shuffle):
            self.dataset = dataset
            self.batch_size = batch_size
            self.shuffle = shuffle

        def __iter__(self):
            idxs = []
            while True:
                if len(idxs) < self.batch_size:
                    idxs = range(len(self.dataset))
                    if self.shuffle:
                        np.random.shuffle(idxs)
                batch = [self.dataset[idx] for idx in idxs[:self.batch_size]]
                yield batch
                idxs = idxs[len(batch):]

ethanwharris (Member, Author), Mar 10, 2020:
torch.DataLoader does quite a bit more than this (e.g. collate functions, samplers, etc.), so it is probably better to wrap it rather than rewrite it - also, we don't really have access to the dataset when this is created, only the dataloader.

Contributor:
@Borda we generally want to avoid duplicating torch functionality. Otherwise the project scope will blow up quickly.

Member:
I do agree, I just found this construction quite difficult to follow...
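
To make the discussion above concrete, CustomInfDataloader stands in for the kind of loader this PR targets: a torch DataLoader built on an IterableDataset exposes no __len__, yet it still provides batching, collation and worker handling, which is why the test wrapper delegates to an existing loader instead of reimplementing one. A minimal sketch (StreamingDataset is an illustrative name, not part of the PR):

    import torch
    from torch.utils.data import DataLoader, IterableDataset


    class StreamingDataset(IterableDataset):
        """An endless stream of samples; deliberately defines no __len__."""

        def __iter__(self):
            while True:
                yield torch.randn(32), torch.randint(0, 10, (1,)).item()


    # The regular DataLoader still batches and collates the stream for us.
    loader = DataLoader(StreamingDataset(), batch_size=16)
    features, labels = next(iter(loader))  # features: [16, 32], labels: [16]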
99 changes: 75 additions & 24 deletions tests/trainer/test_dataloaders.py
@@ -13,6 +13,9 @@
     LightValStepFitMultipleDataloadersMixin,
     LightValStepFitSingleDataloaderMixin,
     LightTrainDataloader,
+    LightInfTrainDataloader,
+    LightInfValDataloader,
+    LightInfTestDataloader
 )
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
@@ -274,31 +277,45 @@ def test_inf_train_dataloader(tmpdir):
     """Test inf train data loader (e.g. IterableDataset)"""
     tutils.reset_seed()
 
-    class CurrentTestModel(LightningTestModel):
-        def train_dataloader(self):
-            dataloader = self._dataloader(train=True)
+    class CurrentTestModel(
+        LightInfTrainDataloader,
+        LightningTestModel
+    ):
+        pass
 
-            class CustomInfDataLoader:
-                def __init__(self, dataloader):
-                    self.dataloader = dataloader
-                    self.iter = iter(dataloader)
-                    self.count = 0
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
 
-                def __iter__(self):
-                    self.count = 0
-                    return self
+    # fit model
+    with pytest.raises(MisconfigurationException):
+        trainer = Trainer(
+            default_save_path=tmpdir,
+            max_epochs=1,
+            val_check_interval=0.5
+        )
+        trainer.fit(model)
 
-                def __next__(self):
-                    if self.count >= 5:
-                        raise StopIteration
-                    self.count = self.count + 1
-                    try:
-                        return next(self.iter)
-                    except StopIteration:
-                        self.iter = iter(self.dataloader)
-                        return next(self.iter)
+    # logger file to get meta
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_check_interval=50
+    )
+    result = trainer.fit(model)
 
-            return CustomInfDataLoader(dataloader)
+    # verify training completed
+    assert result == 1
+
+
+def test_inf_val_dataloader(tmpdir):
+    """Test inf val data loader (e.g. IterableDataset)"""
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightInfValDataloader,
+        LightningTestModel
+    ):
+        pass
 
     hparams = tutils.get_hparams()
     model = CurrentTestModel(hparams)
@@ -308,17 +325,51 @@ def __next__(self):
         trainer = Trainer(
             default_save_path=tmpdir,
             max_epochs=1,
-            val_check_interval=0.5
+            val_percent_check=0.5
         )
         trainer.fit(model)
 
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=1,
-        val_check_interval=50,
+        max_epochs=1
     )
     result = trainer.fit(model)
 
     # verify training completed
     assert result == 1
+
+
+def test_inf_test_dataloader(tmpdir):
+    """Test inf test data loader (e.g. IterableDataset)"""
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightInfTestDataloader,
+        LightningTestModel,
+        LightTestFitSingleTestDataloadersMixin
+    ):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # fit model
+    with pytest.raises(MisconfigurationException):
+        trainer = Trainer(
+            default_save_path=tmpdir,
+            max_epochs=1,
+            test_percent_check=0.5
+        )
+        trainer.test(model)
+
+    # logger file to get meta
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1
+    )
+    result = trainer.fit(model)
+    trainer.test(model)
+
+    # verify training completed
+    assert result == 1
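
Taken together, the tests sketch the intended usage once this lands: with a train_dataloader that has no __len__, fractional settings such as val_check_interval=0.5 or val_percent_check=0.5 raise MisconfigurationException, while an integer val_check_interval (validate every N training batches) works. A short usage sketch, assuming `model` is a LightningModule defined elsewhere whose loaders wrap an IterableDataset:

    from pytorch_lightning import Trainer

    # Fractional intervals need a known epoch length, so this would raise
    # MisconfigurationException for an infinite train loader:
    #   Trainer(max_epochs=1, val_check_interval=0.5).fit(model)

    # An integer interval (validate every 50 training batches) is fine.
    trainer = Trainer(max_epochs=1, val_check_interval=50)
    trainer.fit(model)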