
support of other datatypes for Batches #2465

Closed
Tieftaucher opened this issue Jul 2, 2020 · 3 comments
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)

Comments


Tieftaucher commented Jul 2, 2020

🚀 Feature

My proposal is to support third-party data structures as batches. At the moment you need to override the transfer_batch_to_device method of your model if your batch is not a collection or one of the other supported data types.
My suggestion would be to accept all kinds of data types, as long as they have a to(device) method.
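For illustration, here is a minimal sketch of what such a duck-typed fallback could look like. This is not Lightning's actual implementation, and the helper name move_to_device is made up:

```python
import torch


def move_to_device(batch, device):
    """Illustrative fallback: recurse into collections, otherwise duck-type on .to()."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    # proposed addition: any object exposing a to(device) method is moved via that method
    if callable(getattr(batch, "to", None)):
        return batch.to(device)
    return batch  # unknown objects are returned unchanged
```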

Motivation

I want to use pytorch_geometric, which uses its own DataLoader and its own Batch data type, so I had some trouble using it together with PyTorch Lightning. After some struggling I figured out that I had to override the transfer_batch_to_device method like this:

```python
class Net(pl.LightningModule):
    ...
    def transfer_batch_to_device(self, batch, device):
        return batch.to(device)
```

At the very least, I think it would be nice to mention this necessity in the docs, or to change the default behaviour of transfer_batch_to_device so that overriding it is no longer necessary. A fuller sketch of the workaround follows below.
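To make the workaround concrete, here is an illustrative sketch of a complete script. The GCN architecture and the MUTAG dataset are placeholders chosen for the example, not part of the original report, and the imports assume a pre-2.0 pytorch_geometric where DataLoader lives in torch_geometric.data:

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, global_mean_pool


class Net(pl.LightningModule):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv = GCNConv(num_features, 64)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        # data is a torch_geometric Batch with x, edge_index, batch, y
        x = F.relu(self.conv(data.x, data.edge_index))
        x = global_mean_pool(x, data.batch)
        return self.lin(x)

    def training_step(self, batch, batch_idx):
        out = self(batch)
        return F.cross_entropy(out, batch.y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    # the workaround: PyG's Batch already knows how to move itself
    def transfer_batch_to_device(self, batch, device):
        return batch.to(device)


dataset = TUDataset(root="data/TUD", name="MUTAG")
loader = DataLoader(dataset, batch_size=16, shuffle=True)
pl.Trainer(max_epochs=1).fit(Net(dataset.num_features, dataset.num_classes), loader)
```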

Pitch

transfer_batch_to_device should accept all data types that provide a to(device) method.

Alternatives

Alternatively, the documentation should mention how to use non-default dataloaders and batch types.

Additional context

I saw #1756 but couldn't figure out whether it solves my problem and just isn't merged yet. If it does, sorry for the extra work.

Thank you for the nice library and all your work =)

Tieftaucher added the feature and help wanted labels on Jul 2, 2020

github-actions bot commented Jul 2, 2020

Hi! Thanks for your contribution, great first issue!


nghorbani commented Aug 28, 2020

This is also affecting me. I notice that when using torch-geometric the batch size is not affected either; i.e. if I set batch_size to 16, for example, each call to forward of the LightningModule is given 16 data points on the CPU, and the data points in each forward call on different GPUs are all the same!
Using PL 0.9.0 on Ubuntu 20.04

awaelchli (Member) commented

@Tieftaucher I fixed it here: #2335
As long as your data type implements .to(device), Lightning will call that directly. You don't have to override transfer_batch_to_device in this case.
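For example (an illustrative sketch; MyBatch is a made-up container, not a Lightning or torch-geometric class), any object like this is now moved automatically:

```python
import torch


class MyBatch:
    """Made-up custom batch type; the only requirement is a to(device) method."""

    def __init__(self, features, targets):
        self.features = features
        self.targets = targets

    def to(self, device):
        # move the wrapped tensors, then return self so Lightning can reuse the object
        self.features = self.features.to(device)
        self.targets = self.targets.to(device)
        return self
```

training_step then receives such a batch already on the right device.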

@nghorbani could you open a separate issue about this? If you provide me some code I can help. But note that, last time I checked, torch-geometric did not support distributed multi-GPU (scatter, gather).
