Batched iterative dataloading disables validation #2429

Closed
Uroc327 opened this issue Jun 30, 2020 · 9 comments · Fixed by #2437
Labels
bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)

Comments

Uroc327 commented Jun 30, 2020

🐛 Bug

Setting the batch_size parameter of torch.utils.data.DataLoader to a number greater than 1 prevents validation_step and validation_epoch_end from being called.

To Reproduce

Steps to reproduce the behavior:

  1. Run python main.py with bs = 1
  2. Observe exception raised in validation_step
  3. Run python main.py after changing to bs = 2
  4. Observe the model train successfully

Code sample

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset

class Dataset(IterableDataset):
    def __init__(self):
        super().__init__()

    def __iter__(self):
        for _ in range(1024):
            yield torch.randn(20)

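    # An IterableDataset that also defines __len__ -- the combination the
    # discussion below turns on.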
    def __len__(self):
        return 1024

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fst = nn.Linear(20, 1)
        self.snd = nn.Linear(1, 20)

    def forward(self, x):
        x = self.fst(x)
        x = F.relu(x)
        x = self.snd(x)
        return x

    def training_step(self, batch, batchIdx):
        x = self.forward(batch)
        return {'loss': F.mse_loss(x, batch)}

    def validation_step(self, batch, batchIdx):
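        # Raised on purpose: if validation runs at all, this exception should surface.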
        raise NotImplementedError()
        x = self.forward(batch)
        return {'val_loss': F.mse_loss(x, batch)}

    def validation_epoch_end(self, outputs):
        return {'val_loss': torch.mean(torch.stack([x['val_loss'] for x in outputs]))}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())

if __name__ == '__main__':
    trainer = pl.Trainer(num_sanity_val_steps=0)
    net = Model()
    dataset = Dataset()

    bs = 2
    trainer.fit(net, train_dataloader=DataLoader(dataset, batch_size=bs), val_dataloaders=DataLoader(dataset, batch_size=bs))
> python main.py # with bs=2, should fail
GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | fst  | Linear | 21    
1 | snd  | Linear | 40    
/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:25: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:25: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 1000:  25%|████████████████████████████████████████████████████████████▌                                                                                                                                                                                     | 512/2048 [00:00<00:01, 1396.66it/s, loss=0.890, v_num=20]

Expected behavior

> python main.py # with bs=1
GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name | Type   | Params
--------------------------------
0 | fst  | Linear | 21    
1 | snd  | Linear | 40    
/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:25: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:25: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 1:  50%|████████████████Traceback (most recent call last):████████████████████████████████████████████████████████████████████████                                                                                                                          | 1024/2048 [00:00<00:00, 1337.66it/s, loss=1.032, v_num=19]
  File "main.py", line 57, in <module>
    trainer.fit(net, train_dataloader=DataLoader(dataset, batch_size=bs), val_dataloaders=DataLoader(dataset, batch_size=bs))
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 954, in fit
    self.run_pretrain_routine(model)
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1093, in run_pretrain_routine
    self.train()
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 375, in train
    self.run_training_epoch()
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 490, in run_training_epoch
    self.run_evaluation(test_mode=self.testing)
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 379, in run_evaluation
    eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 281, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 452, in evaluation_forward
    output = model.validation_step(*args)
  File "main.py", line 39, in validation_step
    raise NotImplementedError()
NotImplementedError
Exception ignored in: <object repr() failed>
Traceback (most recent call last):
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/tqdm/std.py", line 1086, in __del__
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/tqdm/std.py", line 1293, in close
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/tqdm/std.py", line 1471, in display
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/tqdm/std.py", line 1089, in __repr__
  File "/home/constantin/.virtualenvs/tensor/lib64/python3.6/site-packages/tqdm/std.py", line 1433, in format_dict
TypeError: 'NoneType' object is not iterable

Environment

Collecting environment information...
PyTorch version: 1.5.1+cu101
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Gentoo Base System release 2.7
GCC version: (Gentoo 9.3.0 p1) 9.3.0
CMake version: version 3.17.3

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GT 730
Nvidia driver version: 440.82
cuDNN version: /opt/cuda/targets/x86_64-linux/lib/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.19.0
[pip3] pytorch-lightning==0.8.1
[pip3] torch==1.5.1+cu101
[pip3] torchvision==0.6.1+cu101
[conda] Could not collect

Additional context

This is basically a reopen of #2351, as the issue is not fixed by changing the batch size or the dataset size.

Uroc327 added the bug (Something isn't working) and help wanted (Open to be worked on) labels Jun 30, 2020
Uroc327 commented Jun 30, 2020

Same behavior with pytorch-lightning==0.8.3.

awaelchli commented Jun 30, 2020

I can reproduce it. It is caused by pl.Trainer(num_sanity_val_steps=0).
For num_sanity_val_steps > 0 it works fine.

awaelchli added the priority: 0 (High priority task) label Jun 30, 2020
awaelchli commented Jun 30, 2020

Trying to fix this right now. Another observation: it only happens with an iterable dataset.
EDIT: an iterable dataset that also has a length defined (see comment below)

Uroc327 commented Jul 1, 2020

@awaelchli thanks for looking into this! If I remember correctly, for num_sanity_val_steps > 0 it only raised the exception during the sanity checks on my machine. When I raise the exception only during actual training (for example, on the fourth call to validation_step), training also completes successfully (i.e. validation is never called).

awaelchli commented:
I think the issue is that your dataset is an IterableDataset but has __len__ defined. PL interprets that wrongly. In fact, there is a weird issue with dataloaders built from iterable datasets that have __len__ defined; check out this Colab that I made:
https://colab.research.google.com/drive/1RQyNLGORe4vOL_RS6khcFNObcxyFxtEm?usp=sharing
Isn't it strange?

If I remove __len__ from the dataset definition, your code sample works.
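
For illustration, a minimal sketch (not from the thread) of the mismatch described above; the class name LengthyIterable is made up, and the commented numbers assume the torch 1.5.x behavior from this report (later releases changed how this length is reported):

import torch
from torch.utils.data import DataLoader, IterableDataset

class LengthyIterable(IterableDataset):
    def __iter__(self):
        for _ in range(1024):
            yield torch.randn(20)

    def __len__(self):
        return 1024

loader = DataLoader(LengthyIterable(), batch_size=2)
print(len(loader))             # 1024 -- taken straight from the dataset's __len__
print(sum(1 for _ in loader))  # 512  -- the number of batches actually yielded

With the reported length twice the number of batches actually produced, the training epoch as the trainer counts it never reaches its end, which is presumably why validation is never triggered.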

awaelchli self-assigned this Jul 1, 2020
awaelchli commented:
@Uroc327 FYI, it turned out there is no bug, but rather a technicality with iterable datasets. We decided to add a warning message for the case where an IterableDataset also defines a length.

In your case you have the following options (a sketch of the first two follows below):

  • remove __len__ from your IterableDataset (this is the option you want 99.9% of the time)
  • convert it to a regular torch.utils.data.Dataset
  • keep the batch size at 1 (not great)
  • write a custom BatchSampler (I guess; I have not tried it)
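
A rough sketch of the first two options (not from the thread; class names are made up):

import torch
from torch.utils.data import Dataset, IterableDataset

# Option 1: an IterableDataset without __len__, so the DataLoader no longer
# reports a misleading length.
class StreamDataset(IterableDataset):
    def __iter__(self):
        for _ in range(1024):
            yield torch.randn(20)

# Option 2: a regular map-style Dataset; the DataLoader handles batching and
# computes the correct number of batches itself.
class MapStyleDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(1024, 20)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]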

Uroc327 commented Jul 1, 2020

@awaelchli Ok, thanks! I'll remove the __len__ implementation.
Btw, the pytorch docs for DataLoader explicitly allow an IterableDataset with a __len__.

awaelchli commented:
Yep! However, as the docs say in the note at the bottom of your link, the user has to do their own batching to avoid duplicate data. That's why I added the warning to PL.
Note that in your case, leaving it as is, len(dataloader) would always return 1024 as the length, regardless of the batch size. That would be incorrect for batch size > 1.
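
A sketch of that "do your own batching" route (names invented for illustration): the dataset yields ready-made batches, __len__ counts batches rather than samples, and automatic batching is disabled with batch_size=None, so the length the DataLoader reports matches what iteration produces.

import torch
from torch.utils.data import DataLoader, IterableDataset

class PreBatchedDataset(IterableDataset):
    def __init__(self, n_samples=1024, batch_size=2):
        self.n_samples = n_samples
        self.batch_size = batch_size

    def __iter__(self):
        for _ in range(len(self)):
            yield torch.randn(self.batch_size, 20)  # one ready-made batch per step

    def __len__(self):
        # Number of batches, not samples.
        return self.n_samples // self.batch_size

loader = DataLoader(PreBatchedDataset(), batch_size=None)  # no automatic batching
assert len(loader) == sum(1 for _ in loader)  # both 512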

Uroc327 commented Jul 1, 2020

Makes sense, thank you 😄
