
How to load data every epoch #231

Closed
sadbb opened this issue Sep 17, 2019 · 12 comments
Labels
feature (Is an improvement or enhancement) · help wanted (Open to be worked on)

Comments

@sadbb

sadbb commented Sep 17, 2019

Hi,
Because of my task, I must load new training data every epoch. But in this package, data can only be loaded once at the beginning of training. How can I load data every epoch?
Thanks.

@sadbb sadbb added the feature and help wanted labels Sep 17, 2019
@sadbb sadbb closed this as completed Sep 17, 2019
@sadbb sadbb reopened this Sep 17, 2019
@sadbb sadbb changed the title from "How to loaddata every epoch" to "How to load data every epoch" Sep 17, 2019
@williamFalcon
Contributor

Could you explain more? Do you have pseudocode?

@sadbb
Author

sadbb commented Sep 18, 2019

Thanks for the help,
[image: overview_flat diagram of the training loop]
As you can see in the picture, data setup is called before the training loop. But in my task, like few-shot learning, the data must be loaded between the "for epoch" and "for batch" loops in this picture, every new epoch. What can I do about this?
Apologies for my poor English.

@sadbb
Author

sadbb commented Sep 19, 2019

I made some adjustments. Can you help me check them?
In decorators.py, I deleted 'setattr(self, attr_name, value)' on line 25, like this:
[screenshot: change to decorators.py]

And in trainer.py, I call the data again in the function '_train' at line 803, like this:
[screenshot: change to trainer.py]

Am I doing this right? Or does it have a bad influence elsewhere?
Thank you for the help.

@neggert
Contributor

neggert commented Sep 23, 2019

What happens if you just don't use the @dataloader decorator? That should prevent the DataLoader from being cached, and you can just recompute it every time the class method is called (which should be once per epoch). Not sure what other effects that would have, though.
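
A minimal sketch of that idea (assuming the method is named train_dataloader, as in the later comments; sample_new_episode is a hypothetical helper that builds the new epoch's data):

from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class FewShotModel(pl.LightningModule):
    # no caching decorator here, so a fresh DataLoader is built
    # every time this method is called
    def train_dataloader(self):
        inputs, targets = sample_new_episode()  # hypothetical per-epoch data sampler
        return DataLoader(TensorDataset(inputs, targets),
                          batch_size=32, shuffle=True)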

@williamFalcon
Contributor

@neggert that's the way to do it. I'll add this to the docs.

@williamFalcon
Contributor

@sadbb actually just submitted a PR to enable this.
Once we merge into master, just remove the decorators in the dataloaders. Warning though: your loader will be called every epoch. If any of the loaders are slow to initialize, your script will be very slow.

williamFalcon added a commit that referenced this issue Oct 1, 2019
@yassersouri
Contributor

I think it is pretty standard to create a dataloader at the beginning of each epoch. I think it should be the default.

@williamFalcon
Contributor

It is already the default. This PR is to support the non-default case.

@neggert
Contributor

neggert commented Oct 11, 2019

One callout: When doing validation, val_dataloader gets called every step. This causes performance problems if you haven't memoized the dataloader.

@neggert
Contributor

neggert commented Oct 11, 2019

The problem is here, in __evaluation_forward:

        if test and len(self.get_test_dataloaders()) > 1:
            args.append(dataloader_idx)

        elif not test and len(self.get_val_dataloaders()) > 1:
            args.append(dataloader_idx)

Honestly, I never really liked passing different args to validation_step depending on the number of dataloaders anyway. Maybe we should think about changing the design slightly here.

@xmodar

xmodar commented Nov 14, 2019

I strongly discourage removing the @pl.dataloader decorator because the rest of the package assumes lazy instantiation, as @neggert pointed out. Also, for example, see how many times self.get_train_dataloader() is being called in data_loading_mixin.py. The best alternative, in my opinion, is to do this logic in your Dataset class, similar to how Subset works. This way, you will not interrupt the sampler's internal reference to your dataset, and you can change it entirely on demand.

from torch.utils.data import Dataset


class DynamicSet(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def reset(self, dataset):
        self.dataset = dataset
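
For the first variant, usage could look like this (a sketch; build_epoch_dataset is a hypothetical helper that produces each epoch's data):

# build the wrapper once and hand it to the DataLoader
train_set = DynamicSet(build_epoch_dataset(epoch=0))
loader = DataLoader(train_set, batch_size=32, shuffle=True)

# at the start of each new epoch, swap the underlying data in place;
# the DataLoader and its sampler keep their reference to train_set
train_set.reset(build_epoch_dataset(epoch=1))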

or you can use indices if desired:

from contextlib import contextmanager
from torch.utils.data import Subset


class DynamicSet(Subset):
    def __init__(self, dataset, indices=None):
        if indices is None:
            indices = list(range(len(dataset)))
        super().__init__(dataset, indices)
        self.enumerated = False

    def __getitem__(self, index):
        out = super().__getitem__(index)
        if self.enumerated:
            out = (self.indices[index], out)
        return out

    def reveal(self, indices):
        self.indices = list(set(self.indices).union(indices))

    @property
    def hidden(self):
        return list(set(range(len(self.dataset))) - set(self.indices))

    @contextmanager
    def state(self, indices=None, enumerated=None):
        try:
            if indices is not None:
                old_indices = self.indices
                self.indices = indices
            if enumerated is not None:
                old_enumerated = self.enumerated
                self.enumerated = enumerated
            yield
        finally:
            if indices is not None:
                self.indices = old_indices
            if enumerated is not None:
                self.enumerated = old_enumerated

Usage example (active learning):

# inside model.train_dataloader():
train_set = DynamicSet(dataset, indices=[])

# inside model.on_epoch_start():
train_set = model.train_dataloader().dataset
if not train_set.indices:
    # use the first half of the dataset in the beginning
    train_set.indices = list(range(len(train_set.dataset) // 2))
else:
    # add to the dataset the hidden items that pass the threshold
    with train_set.state(train_set.hidden, enumerated=True):
        indices = [i for i, item in train_set if pass_threshold(item)]
    train_set.reveal(indices)

Side Note:

@williamFalcon There is a bug here and here: isinstance(self.get_train_dataloader(), IterableDataset) should be replaced by isinstance(self.get_train_dataloader().dataset, IterableDataset).

@moi90
Contributor

moi90 commented Sep 22, 2023

To me, it is still not entirely clear how to achieve this in the best way...

In the past, I used Trainer(reload_dataloaders_every_n_epochs=1) in combination with LightningModule.train_dataloader() and did all the data selection there.
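
A minimal sketch of that pattern (assuming a recent Lightning API; full_dataset and select_indices_for_epoch are hypothetical placeholders for your own data and selection logic):

from torch.utils.data import DataLoader, Subset
import pytorch_lightning as pl


class MyModule(pl.LightningModule):
    def train_dataloader(self):
        # with reload_dataloaders_every_n_epochs=1 this method is called
        # again before each training epoch, so the selection is recomputed
        indices = select_indices_for_epoch(self.current_epoch)  # hypothetical helper
        return DataLoader(Subset(full_dataset, indices),
                          batch_size=32, shuffle=True)


trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)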

Another option (if just different parts of the same dataset need to be used) would be torch.utils.data.DataLoader(sampler=...) with a custom sampler which resamples the dataset every epoch.
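
A minimal sketch of such a sampler (the DataLoader calls __iter__ once per epoch, so the subset is redrawn each time; here it is just a random subset, but any per-epoch selection logic would work):

import torch
from torch.utils.data import DataLoader, Sampler


class ResamplingSampler(Sampler):
    """Draws a fresh set of indices every time __iter__ is called,
    i.e. once per epoch when used inside a DataLoader."""

    def __init__(self, dataset, num_samples):
        self.dataset = dataset
        self.num_samples = num_samples

    def __iter__(self):
        # recompute this epoch's indices
        perm = torch.randperm(len(self.dataset))[:self.num_samples]
        return iter(perm.tolist())

    def __len__(self):
        return self.num_samples


# loader = DataLoader(dataset, batch_size=32,
#                     sampler=ResamplingSampler(dataset, num_samples=1000))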
