
Clean up dataloader logic #926

Merged: 80 commits merged from dataloaders into master on Feb 25, 2020
Conversation

williamFalcon (Contributor) commented Feb 24, 2020

Fixes #928
Fixes #927
Fixes #922
Fixes #909
Fixes #859
Fixes #902

Removes the @data_loader decorator

# old
@data_loader
def train_data(...):

# new
def train_data(...):

Adds prepare_data

Lightning needs a step that downloads the data on proc 0 only:

def prepare_data(self):
    # do the actual downloads

def train_data(...):
    # return the dataloader
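
A minimal sketch of how the two hooks fit together in a LightningModule (only the data hooks are shown; the hook name train_dataloader and the MNIST transform follow the diff excerpts further down, while the data root and batch size are illustrative):

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class MyModel(pl.LightningModule):

    def prepare_data(self):
        # runs on proc 0 only: do the actual download here
        MNIST(root='./data', train=True, download=True)

    def train_dataloader(self):
        # runs on every process: build and return the DataLoader
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root='./data', train=True, download=False,
                        transform=transform)
        return DataLoader(dataset, batch_size=32, shuffle=True)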

Added new flags

# a fast progress-bar refresh can freeze notebooks, so throttle it here
Trainer(progress_bar_refresh_rate=50)

# optionally reload the dataloaders every epoch (off by default)
Trainer(reload_dataloaders_every_epoch=False)

Fixes .fit with data

The .fit(dataloaders) path was buggy. It has been simplified to hook into the rest of the framework instead of running its own ad-hoc process.
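
For example, a self-contained sketch of the repaired path (the train_dataloader keyword follows __set_fit_dataloaders in the diff below; the model and data are toy stand-ins, not code from this PR):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


train_loader = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,))),
    batch_size=16, shuffle=True)

trainer = pl.Trainer(max_epochs=1)
# the loader passed to .fit is patched onto the model's train_dataloader hook
trainer.fit(TinyModel(), train_dataloader=train_loader)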

Automatic sampler

Users no longer have to mess around with samplers for DDP or TPU; Lightning sets the sampler up automatically.
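
In user code the change looks roughly like this (a sketch with a toy dataset; the manual version is shown commented out because a DistributedSampler needs an initialized process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

# before: under DDP, users attached the sampler themselves
# sampler = DistributedSampler(dataset)
# loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=sampler)

# now: return a plain DataLoader from train_dataloader(); when running on
# DDP/DDP2/TPU, Lightning rebuilds it with the appropriate DistributedSampler
loader = DataLoader(dataset, batch_size=32, shuffle=True)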

pep8speaks commented Feb 24, 2020

Hello @williamFalcon! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-02-25 03:14:52 UTC

williamFalcon changed the title from "Dataloaders [WIP]" to "Dataloaders" on Feb 25, 2020
williamFalcon changed the title from "Dataloaders" to "Clean up dataloader logic" on Feb 25, 2020
williamFalcon merged commit 1015a00 into master on Feb 25, 2020
williamFalcon deleted the dataloaders branch on February 25, 2020 03:31
dl_args = {
    'dataset': dataloader.dataset,
    'batch_size': dataloader.batch_size,
    'shuffle': False,
Contributor: what if a user wants to shuffle batches (when running on a single machine)? I see below that in certain cases you're re-setting this value to False; did you intend to have it set to True here?

Member: I would rather move shuffle to the arguments, since the other values are taken from the dataloader and only this one is fixed.

Borda (Member) left a review: it is really huge, so these are just my quick comments...

def prepare_data(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root=self.hparams.data_root, train=True,
Member: duplicated

@@ -1,5 +1,6 @@
import traceback
Member: add a warning here too

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False
Member: I would rather write it as:

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True

dl_args = {
    'dataset': dataloader.dataset,
    'batch_size': dataloader.batch_size,
    'shuffle': False,
Member: I would rather move shuffle to the arguments, since the other values are taken from the dataloader and only this one is fixed.

if train:
    if self.use_ddp or self.use_ddp2:
        sampler = DistributedSampler(dataloader.dataset)
        dl_args['shuffle'] = False
Member: why set this if it is already fixed as False?

        warnings.warn(msg)
        break

    def init_test_dataloader(self, model):
Member: I guess this can simply be unified, since the content is almost the same.

def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
    # when dataloader is passed via fit, patch the train_dataloader
    # functions to overwrite with these implementations
    if train_dataloader is not None:
Member: this may be unified...

def prepare_data(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = TestingMNIST(root=self.hparams.data_root, train=True,
Member: duplicated


# acc
labels_hat = torch.argmax(y_hat, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
Member: isn't it already a tensor here?

return output


class LightningTestFitMultipleTestDataloadersMixin:
Member: it is not easy to see what the difference is compared to LightningTestFitSingleTestDataloadersMixin.

versatran01 commented

I still feel this puts too much restriction on the dataloader. A dataloader is a very abstract thing; all we ask from it is the next batch of data, or maybe the size of the dataset. The current implementation assumes it is a PyTorch DataLoader and tries to access dataloader.batch_size from it.

williamFalcon (Contributor, Author) commented

I don't disagree. Maybe a good approach is to check that it's a PyTorch DataLoader? What other dataloaders are there?

versatran01 commented Feb 26, 2020

What I meant is that Lightning should not touch the dataloader the user provides unless necessary. In the current master branch, reset_train_dataloader() tries to create a new torch DataLoader from the given one. It calls auto_add_sampler(), which says it will not do anything if there is already a sampler, but for some reason it does not behave that way and creates a new dataloader out of the given one. I cannot find a way to disable this.

https://github.com/PyTorchLightning/pytorch-lightning/blob/be244560b24b68b0236a4694707fb9bb63c2e6d0/pytorch_lightning/trainer/data_loading.py#L92

I can open an issue if you think that's a better place for this discussion.

williamFalcon (Contributor, Author) commented

We could make this a method you can override in the LightningModule.

What use case do you have that needs the original loader to be kept untouched?

We could also add a flag to the Trainer: auto_add_sampler=True by default.

versatran01 commented Feb 26, 2020

I have multiple dataloaders, each of which loads images in order. I have my own sampler that makes sure each loaded batch is in order, while the batches themselves can be random. I then use itertools to chain these loaders so that they look like one loader (see the sketch below). This worked fine until this PR. Basically, all I need is for Lightning not to touch my loader when I don't need fancy features like TPU or DDP. What's your suggestion?

The auto_add_sampler function says it shouldn't do anything when the user provides a sampler; we should at least fix that part.
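
For context, a minimal sketch of the chaining pattern described above (toy TensorDatasets stand in for the real image data; this is not code from the PR). Note that the chained object is a plain iterator with no dataset or batch_size attributes, which is exactly the kind of access the new dataloader handling performs:

import itertools
import torch
from torch.utils.data import DataLoader, TensorDataset

# each loader keeps samples in order within a batch (shuffle=False);
# a custom sampler could still randomize the order of the batches themselves
ds_a = TensorDataset(torch.arange(0, 10).float().unsqueeze(1))
ds_b = TensorDataset(torch.arange(10, 20).float().unsqueeze(1))
loader_a = DataLoader(ds_a, batch_size=4, shuffle=False)
loader_b = DataLoader(ds_b, batch_size=4, shuffle=False)

# chain them so downstream code iterates over them as if they were one loader
combined = itertools.chain(loader_a, loader_b)

for (batch,) in combined:
    print(batch.shape)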

ethanwharris (Member) commented Feb 26, 2020

We're currently addressing this in the fix for #953 and will PR soon.

The solution is to rewrite the auto_add_sampler method so that it only creates a new dataloader when doing DDP, DDP2, or TPU, and just returns the given dataloader otherwise.
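
In other words, something along these lines (a simplified illustration of the intended behavior, not the actual patch):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def auto_add_sampler(dataloader, use_ddp=False, use_ddp2=False, use_tpu=False):
    # non-distributed runs: hand the user's loader back untouched
    if not (use_ddp or use_ddp2 or use_tpu):
        return dataloader

    # distributed runs: rebuild the loader around a DistributedSampler
    sampler = DistributedSampler(dataloader.dataset)
    return DataLoader(dataset=dataloader.dataset,
                      batch_size=dataloader.batch_size,
                      shuffle=False,  # shuffling is delegated to the sampler
                      sampler=sampler,
                      num_workers=dataloader.num_workers)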

versatran01 commented

@ethanwharris I will repost my concerns there then. Thanks for the pointer.

Borda (Member) commented Feb 26, 2020

@versatran01 that would be cool, looking forward to your points 🤖

Borda added the "feature (Is an improvement or enhancement)" label on Mar 3, 2020
tullie pushed a commit to tullie/pytorch-lightning that referenced this pull request Apr 3, 2020
* added get dataloaders directly using a getter
* deleted decorator
* added prepare_data hook
* refactored dataloader init (x2)
* added dataloader reset flag and main loop (x3)
* made changes (x46)
* fixed bad loaders (x9)
* fixed error in .fit with loaders (x13)
* fixes Lightning-AI#909 (x2)
* bug fix
* Fixes Lightning-AI#902