Skip to content

Commit

Permalink
Clean up dataloader logic (#926)
Browse files Browse the repository at this point in the history
* added get dataloaders directly using a getter

* deleted decorator

* added prepare_data hook

* refactored dataloader init

* refactored dataloader init

* added dataloader reset flag and main loop

* added dataloader reset flag and main loop

* added dataloader reset flag and main loop

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* made changes

* fixed bad loaders

* fixed bad loaders

* fixed bad loaders

* fixed bad loaders

* fixed bad loaders

* fixed bad loaders

* fixed bad loaders

* fixed bad loaders

* fixed bad loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixed error in .fit with loaders

* fixes #909

* fixes #909

* bug fix

* Fixes #902
  • Loading branch information
williamFalcon committed Feb 25, 2020
1 parent c56ee8b commit 1015a00
Show file tree
Hide file tree
Showing 18 changed files with 615 additions and 359 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added automatic sampler setup. Depending on DDP or TPU, lightning configures the sampler correctly (user needs to do nothing) ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
- Added `reload_dataloaders_every_epoch=False` flag for trainer. Some users require reloading data every epoch ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
- Added `progress_bar_refresh_rate=50` flag for trainer. Throttle refresh rate on notebooks ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
- Updated governance docs
- Added a check to ensure that the metric used for early stopping exists before training commences ([#542](https://github.com/PyTorchLightning/pytorch-lightning/pull/542))
- Added `optimizer_idx` argument to `backward` hook ([#733](https://github.com/PyTorchLightning/pytorch-lightning/pull/733))
Expand All @@ -22,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Removed `@data_loader` decorator ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
- Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
- Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
Expand Down
6 changes: 3 additions & 3 deletions docs/source/hooks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ Training set-up
- init_optimizers
- configure_apex
- configure_ddp
- get_train_dataloader
- get_test_dataloaders
- get_val_dataloaders
- train_dataloader
- test_dataloaders
- val_dataloaders
- summarize
- restore_weights

Expand Down
20 changes: 9 additions & 11 deletions pl_examples/basic_examples/lightning_module_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,37 +192,35 @@ def __dataloader(self, train):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root=self.hparams.data_root, train=train,
transform=transform, download=True)
transform=transform, download=False)

# when using multi-node (ddp) we need to add the datasampler
train_sampler = None
batch_size = self.hparams.batch_size

if self.use_ddp:
train_sampler = DistributedSampler(dataset)

should_shuffle = train_sampler is None
loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=should_shuffle,
sampler=train_sampler,
num_workers=0
)

return loader

@pl.data_loader
def prepare_data(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root=self.hparams.data_root, train=True,
transform=transform, download=True)
dataset = MNIST(root=self.hparams.data_root, train=False,
transform=transform, download=True)

def train_dataloader(self):
log.info('Training data loader called.')
return self.__dataloader(train=True)

@pl.data_loader
def val_dataloader(self):
log.info('Validation data loader called.')
return self.__dataloader(train=False)

@pl.data_loader
def test_dataloader(self):
log.info('Test data loader called.')
return self.__dataloader(train=False)
Expand Down
29 changes: 6 additions & 23 deletions pytorch_lightning/core/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import traceback
from functools import wraps
import warnings


def data_loader(fn):
Expand All @@ -8,27 +9,9 @@ def data_loader(fn):
:param fn:
:return:
"""
wraps(fn)
attr_name = '_lazy_' + fn.__name__
@wraps(fn)
def _get_data_loader(self):
try:
value = getattr(self, attr_name)
except AttributeError:
try:
value = fn(self) # Lazy evaluation, done only once.
if (
value is not None and
not isinstance(value, list) and
fn.__name__ in ['test_dataloader', 'val_dataloader']
):
value = [value]
except AttributeError as e:
# Guard against AttributeError suppression. (Issue #142)
traceback.print_exc()
error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
raise RuntimeError(error) from e
setattr(self, attr_name, value) # Memoize evaluation.
return value
w = 'data_loader decorator deprecated in 0.6.1. Will remove 0.8.0'
warnings.warn(w)

return _get_data_loader
def inner_fx(self):
return fn(self)
return inner_fx
41 changes: 37 additions & 4 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel

try:
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True

except ImportError:
XLA_AVAILABLE = False


class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

Expand Down Expand Up @@ -798,7 +805,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, sec
optimizer.zero_grad()
"""
if isinstance(optimizer, torch.optim.LBFGS):
if self.trainer.use_tpu and XLA_AVAILABLE:
xm.optimizer_step(optimizer)
elif isinstance(optimizer, torch.optim.LBFGS):
optimizer.step(second_order_closure)
else:
optimizer.step()
Expand Down Expand Up @@ -868,7 +877,33 @@ def tbptt_split_batch(self, batch, split_size):

return splits

@data_loader
def prepare_data(self):
"""Use this to download and prepare data.
In distributed (GPU, TPU), this will only be called once
:return: PyTorch DataLoader
This is called before requesting the dataloaders
.. code-block:: python
model.prepare_data()
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
Example
-------
.. code-block:: python
def prepare_data(self):
download_imagenet()
clean_imagenet()
cache_imagenet()
"""
return None

def train_dataloader(self):
"""Implement a PyTorch DataLoader
Expand Down Expand Up @@ -908,7 +943,6 @@ def tng_dataloader(self): # todo: remove in v0.8.0
" and this method will be removed in v0.8.0", DeprecationWarning)
return output

@data_loader
def test_dataloader(self):
r"""
Expand Down Expand Up @@ -942,7 +976,6 @@ def test_dataloader(self):
"""
return None

@data_loader
def val_dataloader(self):
r"""
Expand Down
Loading

0 comments on commit 1015a00

Please sign in to comment.