
device property #1791

Merged 9 commits into master on May 13, 2020
Conversation

@Borda (Member) commented May 12, 2020

What does this PR do?

Fixes #1790 (comment)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@Borda added the "feature" (Is an improvement or enhancement) label, May 12, 2020
@awaelchli (Member) left a comment

I like it read-only :)

Consider also my comment here for better code style and flexibility:
#1790 (comment)

- self.device = torch.device('cuda', self.root_gpu)
+ self._device = torch.device('cuda', self.root_gpu)
Member:
we could remove all of these calls in the trainer by overloading .to() and .cuda() in LightningModule and setting the device there.

Member:

You also need to overload .cpu() :)

Contributor:

Sorry, maybe I'm missing something. The point of self.device is to have a read-only property for creating tensors directly on the right device.

@awaelchli (Member), May 12, 2020:

If we overload the .to() method like this, for example:

def to(self, device):
    # remember the target device, then defer to nn.Module.to()
    self._device = device
    return super().to(device)

Then we get the following benefits:

  • The self.device property will not break when LightningModule is used as a plain nn.Module without the Trainer.
  • When a LightningModule is nested inside another LightningModule and the user calls .to(), the self.device properties of the submodules get updated as well.
  • The Trainer code does not need to set the device; it calls .to() anyway, so the code is in one place and is easier to maintain.

I see only benefits atm

Member:

@justusschock also did it like this for the metrics package

@Borda (Member, Author):

so we need to override the following methods:

  • .to(...)
  • .cpu()
  • .cuda()

or am I missing any? @awaelchli ^^

@awaelchli (Member), May 12, 2020:

yep, exactly, although I suspect cpu and cuda already call .to internally. Not sure, need to check. EDIT: nope, they don't; we need all three :)
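For illustration, a minimal sketch of overriding all three methods to keep a private _device attribute in sync (hypothetical class, attribute, and helper names, not the code from this PR):

    import torch
    from torch import nn


    class DeviceTrackingModule(nn.Module):
        """Sketch: remembers the device the module was last moved to."""

        def __init__(self):
            super().__init__()
            self._device = torch.device('cpu')

        @property
        def device(self) -> torch.device:
            # read-only on purpose: there is no setter
            return self._device

        def to(self, *args, **kwargs):
            # simplified extraction covering .to(device), .to('cuda:1') and .to(device=...)
            device = kwargs.get('device', args[0] if args else None)
            if isinstance(device, str):
                device = torch.device(device)
            if isinstance(device, torch.device):
                self._device = device
            return super().to(*args, **kwargs)

        def cuda(self, device=None):
            # `device` is assumed to be an int GPU index or None in this sketch
            self._device = torch.device('cuda', device) if device is not None else torch.device('cuda')
            return super().cuda(device=device)

        def cpu(self):
            self._device = torch.device('cpu')
            return super().cpu()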

@Borda (Member, Author):

well, it seems to me that ideally we want to reuse the whole template from Metrics...

@awaelchli (Member), May 12, 2020:

probably not all of it. The device setter and dtype do not apply to LightningModule, I think? I agree we should try to avoid code duplication.
@justusschock what do you think?

Member:

I think, while we are at it, we should think about introducing the same for dtype, since when I create a tensor in a function it usually involves a certain dtype as well. Although I'm not sure if this would be reflected by AMP as well...

codecov bot commented May 12, 2020

Codecov Report

Merging #1791 into master will decrease coverage by 0%.
The diff coverage is 64%.

@@          Coverage Diff           @@
##           master   #1791   +/-   ##
======================================
- Coverage      88%     88%   -0%     
======================================
  Files          69      69           
  Lines        4316    4322    +6     
======================================
+ Hits         3805    3809    +4     
- Misses        511     513    +2     

@Borda changed the title from "device property" to "[wip] device property", May 12, 2020
@williamFalcon (Contributor) commented:
yeah, good catch. This is meant as read-only.

the motivation is to support tensors on device directly.

torch.rand(..., device=self.device)
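As a usage sketch (a hypothetical training_step, not code from this PR; self.loss is an assumed criterion), the read-only property lets the module allocate tensors on the correct device without an extra .to() call:

    def training_step(self, batch, batch_idx):
        x, y = batch
        # the noise tensor is created directly on the module's device
        noise = torch.rand(x.size(0), x.size(1), device=self.device)
        return self.loss(self(x + noise), y)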


@williamFalcon changed the title from "[wip] device property" to "device property", May 12, 2020
mergify bot commented May 12, 2020

Great job! =)

@williamFalcon changed the title from "device property" to "[wip] device property", May 12, 2020
@Borda changed the title from "[wip] device property" to "device property", May 12, 2020
@Borda (Member, Author) commented May 12, 2020

@awaelchli @justusschock I have just copy-pasted the basic template from Metrics, as it contains all we need for now, and later we can just inherit it back... are you fine with this solution?

@justusschock (Member) left a comment

You probably need to add the corresponding tests as well :)

5 review threads on pytorch_lightning/core/properties.py (outdated, resolved)
@justusschock (Member) commented:

Also, I'd probably rename the mixin to reflect which properties it provides, to something like DeviceDtypeModuleMixin.
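To illustrate the dtype side justusschock mentions, a sketch only (hypothetical mixin name, assuming the device tracking is handled as in the earlier .to()/.cuda()/.cpu() sketch): the casting shortcuts would be overridden in the same spirit.

    import torch
    from torch import nn


    class DtypeTrackingMixin(nn.Module):
        """Sketch: keeps a dtype property in sync with casts applied to the module."""

        def __init__(self):
            super().__init__()
            self._dtype = torch.float32

        @property
        def dtype(self) -> torch.dtype:
            return self._dtype

        def half(self):
            self._dtype = torch.float16
            return super().half()

        def float(self):
            self._dtype = torch.float32
            return super().float()

        def double(self):
            self._dtype = torch.float64
            return super().double()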

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
@Borda (Member, Author) commented May 12, 2020

seems to fail on an unrelated doctest (I have not changed the Trainer);
looks like the ellipsis does not work...

@Borda requested a review from justusschock, May 12, 2020 13:32
@justusschock (Member) left a comment

LGTM, just one question

@@ -529,6 +529,10 @@ def __init__(
# Callback system
self.on_init_end()

@property
def device(self) -> Union[None, str, object]:
Member:

To me it is not clear why the Trainer should have such a property as well.

@Borda (Member, Author):

it was there before, I just made it read-only, but I agree that it is strange

Member:

I would remove it, it's not needed as far as I can tell, since now it has shifted over to the module

@williamFalcon (Contributor) commented:

i love how this escalated haha

@Borda added this to the 0.7.6 milestone, May 12, 2020
@Borda (Member, Author) commented May 12, 2020

it seems that there is an API change in PyTorch 1.5:

    device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)
ValueError: too many values to unpack (expected 3)

but it is strange, because torch master seems to use the same code: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

@williamFalcon merged commit 10ce1c0 into master, May 13, 2020
@Borda deleted the feature/device branch, May 13, 2020 06:47
@justusschock (Member) commented:

@Borda Just FYI: it doesn't use the same signature. There is an additional argument for the memory format introduced...
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L453

@Borda (Member, Author) commented May 13, 2020

> @Borda Just FYI: it doesn't use the same signature. There is an additional argument for the memory format introduced...
> https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L453

yes, but in PyTorch < 1.5 it returns only three output values, right...
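For context, the private helper torch._C._nn._parse_to returns (device, dtype, non_blocking) on PyTorch < 1.5 and appends a fourth value (the memory/convert format) in 1.5, which is what makes the three-way unpacking above fail. A hedged sketch of a version-tolerant override (illustrative only, since this relies on a private API and on the _device/_dtype attributes assumed in the earlier sketches):

    def to(self, *args, **kwargs):
        # take only the first three values so both the pre-1.5 (3 values)
        # and the 1.5+ (4 values, memory format appended) return shapes work
        parsed = torch._C._nn._parse_to(*args, **kwargs)
        device, dtype, non_blocking = parsed[:3]
        if device is not None:
            self._device = device
        if dtype is not None:
            self._dtype = dtype
        return super().to(*args, **kwargs)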

Labels: feature (Is an improvement or enhancement)
4 participants