
Add IterableDataset support #323

Closed

antvconst opened this issue Oct 7, 2019 · 19 comments
Labels
feature (Is an improvement or enhancement), help wanted (Open to be worked on)

Comments

@antvconst
Contributor

It looks like there is currently no way to use an IterableDataset instance for training. Trying to do so crashes with this exception:

Traceback (most recent call last):
  File "main.py", line 12, in <module>
    trainer.fit(model)
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 677, in fit
    self.__single_gpu_train(model)
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 725, in __single_gpu_train
    self.__run_pretrain_routine(model)
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 882, in __run_pretrain_routine
    self.__layout_bookeeping()
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 436, in __layout_bookeeping
    self.nb_training_batches = len(self.train_dataloader)
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 297, in __len__
    return len(self._index_sampler)  # with iterable-style dataset, this will error
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/torch/utils/data/sampler.py", line 212, in __len__
    return (len(self.sampler) + self.batch_size - 1) // self.batch_size
  File "/home/akonstantinov/.miniconda3/envs/cpn/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 57, in __len__
    raise TypeError('Cannot determine the DataLoader length of a IterableDataset')
TypeError: Cannot determine the DataLoader length of a IterableDataset
@antvconst added the bug (Something isn't working) label Oct 7, 2019
@williamFalcon
Contributor

Got it.

Since it's impossible to know when to run the validation check, checkpoint, or stop training, the workaround is to report a really high number from the len of your dataloader.
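
A minimal sketch of that workaround (the class name and the number here are arbitrary, not part of the library): subclass DataLoader so that len() reports a very large batch count instead of raising.

from torch.utils.data import DataLoader

class BigLenDataLoader(DataLoader):
    """Workaround sketch: report an arbitrarily large length so that
    len(train_dataloader) in the Trainer's bookkeeping does not crash
    when the underlying dataset is an IterableDataset."""

    def __len__(self):
        return 10_000_000  # arbitrary, just "a really high number" of batches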

@williamFalcon added the feature and help wanted labels and removed the bug label Oct 10, 2019
@williamFalcon changed the title from "IterableDataset is not supported" to "Add IterableDataset support" Oct 10, 2019
@williamFalcon
Contributor

To support this we should modify the training loop so it does the validation check, etc., every k batches. We might also need to disable the tqdm total because we won't know the length.

The use case is for streaming data or database-type reads.

@antvconst
Contributor Author

There is also a simpler case, where the length is actually known but random access by index is not available. That is true in my case: my dataset generates samples on the fly, but always a fixed number per epoch. Say, every epoch 10k samples are generated and fed into the model in batches of 100 samples.
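
For concreteness, a minimal sketch of that case (the class name and the random tensors are made up): an iterable-style dataset that generates a fixed number of samples per epoch and knows its own length, even though it cannot be indexed.

import torch
from torch.utils.data import IterableDataset

class OnTheFlyDataset(IterableDataset):
    def __init__(self, samples_per_epoch=10_000):
        self.samples_per_epoch = samples_per_epoch

    def __iter__(self):
        # Samples are generated on the fly; random tensors stand in here.
        for _ in range(self.samples_per_epoch):
            yield torch.randn(16), torch.randint(0, 2, (1,))

    def __len__(self):
        # The per-epoch count is fixed and known, even without random access.
        return self.samples_per_epoch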

@williamFalcon
Contributor

So, actually all of this can be solved by adding a way to say how many batches you want per epoch. Then everything just works out.

Trainer(max_epoch_batches=10000)

@williamFalcon
Contributor

@neggert any thoughts?

@neggert
Contributor

neggert commented Oct 11, 2019

In the past, I've handled this by storing a num_batches attribute in a custom batch sampler (which I needed to use for other reasons). Then we just do this:

    def __len__(self):
        # Known number of batches per epoch, so len(DataLoader) works.
        return self.num_batches

    def _get_batch(self):
        ...

    def __iter__(self):
        return iter(self._get_batch() for _ in range(self.num_batches))

This is probably not a good general solution, as it would have been a lot of work if I hadn't been planning on using a custom batch sampler anyway.
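
To make the pattern above concrete, a self-contained sketch (all names and sizes here are made up, not from the comment above): a batch sampler that draws random index batches and reports a fixed count, so len(DataLoader) resolves through it.

import random
from torch.utils.data import DataLoader, Dataset

class FixedCountBatchSampler:
    """Yields num_batches random index batches and reports that count as its length."""

    def __init__(self, dataset_size, batch_size, num_batches):
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        return (random.sample(range(self.dataset_size), self.batch_size)
                for _ in range(self.num_batches))

class SquaresDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return idx * idx

loader = DataLoader(SquaresDataset(), batch_sampler=FixedCountBatchSampler(1000, 4, 100))
assert len(loader) == 100  # DataLoader.__len__ delegates to the batch sampler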

For a general solution, I think a max_epoch_batches arg is a good idea.

We do need to be a little bit careful, as there are some pitfalls around using an IterableDataset with multiple workers or nodes. It would be good to warn users about these like we do with DistributedSampler.
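
One such pitfall, sketched below (the splitting scheme follows the pattern from the PyTorch IterableDataset docs): with num_workers > 0 every worker gets its own copy of the dataset, so unless __iter__ shards the stream per worker, each sample is returned num_workers times.

import math
from torch.utils.data import IterableDataset, get_worker_info

class ShardedRangeDataset(IterableDataset):
    """Each worker yields a disjoint slice of [start, end); without this
    sharding, every worker would replay the full range."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process data loading
            return iter(range(self.start, self.end))
        per_worker = int(math.ceil((self.end - self.start) / info.num_workers))
        shard_start = self.start + info.id * per_worker
        shard_end = min(shard_start + per_worker, self.end)
        return iter(range(shard_start, shard_end))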

@williamFalcon
Contributor

williamFalcon commented Oct 21, 2019

So, the resolution here is to add an argument:

Trainer(max_epoch_batches=10000), which runs the validation set at that interval (and overrides all the other settings for this).

(open to a better name for this)

@Borda @neggert any of u guys want to take a stab at this?

@williamFalcon
Contributor

williamFalcon commented Oct 22, 2019

Added in #405

To use:

Trainer(val_check_interval=100)

(checks val every 100 training batches)
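
A minimal usage sketch, assuming a LightningModule subclass (MyModel here is a placeholder) whose train_dataloader() wraps an IterableDataset:

from pytorch_lightning import Trainer

model = MyModel()  # hypothetical LightningModule built on an IterableDataset

# Run validation every 100 training batches instead of once per (unknown-length) epoch.
trainer = Trainer(val_check_interval=100)
trainer.fit(model)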

@falceeffect please verify it works as requested. Otherwise we can reopen

@armancohan

Doesn't solve the problem

@calclavia
Contributor

Same here. It still gives an error.

@williamFalcon
Contributor

@calclavia which problem? compatibility? (@MikeScarp)

@armancohan

@williamFalcon The original issue reported here is not fixed by #405. I am still unable to train with an instance of IterableDataset.

@calclavia
Contributor

Right. Setting val_check_interval works, but it doesn't stop PyTorch Lightning from asking for the len of the dataset, which leads to a crash since IterableDataset doesn't support the len call.

@williamFalcon
Contributor

@calclavia mind submitting a PR? where is the len being asked? i thought we specifically handled that case

@absudabsu

absudabsu commented Dec 12, 2019

I tried defining __len__(self) in my dataset class (which inherits from torch.utils.data.IterableDataset), but it still didn't work.
For me, the actual error occurs on line 297 of torch/utils/data/dataloader.py, when it tries to call len(self._index_sampler).
It would be cool if pytorch-lightning supported IterableDataset by calling the right torch functions.
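
For reference, a minimal reproduction of what is described above (matching the torch versions shown in the tracebacks in this thread): defining __len__ on the dataset itself does not help, because DataLoader.__len__ goes through the index sampler.

from torch.utils.data import DataLoader, IterableDataset

class StreamWithLen(IterableDataset):
    def __iter__(self):
        return iter(range(100))

    def __len__(self):
        return 100  # known length on the dataset...

# ...but len(DataLoader(...)) still asks the index sampler and raises
# TypeError: Cannot determine the DataLoader length of a IterableDataset
len(DataLoader(StreamWithLen()))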

@felixlaumon

It appears there is a typo in the latest pip-installable version 0.5.3.2: https://github.com/williamFalcon/pytorch-lightning/blob/0.5.3.2/pytorch_lightning/trainer/data_loading_mixin.py#L27.

It should be isinstance(self.get_train_dataloader().dataset, IterableDataset) instead of isinstance(self.get_train_dataloader(), IterableDataset).

This was later fixed in #549, but it has not been released yet.

@williamFalcon will you be able to make a new release soon since there was no release in December? Thanks!

@matthew-z
Contributor

Actually, even the latest master branch still has this problem.

Traceback (most recent call last):
  File "scripts/msmacro.py", line 119, in <module>
    main()
  File "scripts/msmacro.py", line 115, in main
    trainer.fit(model)
  File "/home/zhaohao/.anaconda3/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 417, in fit
    self.run_pretrain_routine(model)
  File "/home/zhaohao/.anaconda3/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 481, in run_pretrain_routine
    self.get_dataloaders(ref_model)
  File "/home/zhaohao/.anaconda3/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 199, in get_dataloaders
    self.init_train_dataloader(model)
  File "/home/zhaohao/.anaconda3/envs/pytorch/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 78, in init_train_dataloader
    self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
OverflowError: cannot convert float infinity to integer

@Borda
Member

Borda commented Jan 17, 2020

@matthew-z Could you please reopen this issue or make another one?

@matthew-z
Contributor

I don't have the privilege to reopen a closed issue, so I will open a new one.
