Device #1790
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master    #1790    +/-   ##
=========================================
- Coverage      88%      88%     -0%
=========================================
  Files          69       69
  Lines        4304     4312      +8
=========================================
+ Hits         3796     3801      +5
- Misses        508      511      +3
If the LightningModule gets used in a context outside of Lightning (simply as an nn.Module), then moving the module with `model.to(xm.xla_device())` will not keep the device stored here in sync:

    self.device = xm.xla_device()
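A minimal sketch (not Lightning code; `CachedDeviceModule` is made up for illustration) of why a cached device attribute can go stale when the module is moved as a plain nn.Module:

```python
import torch
import torch.nn as nn

# Illustrative module that caches its device in an attribute, mirroring the
# pattern above (self.device = xm.xla_device()). The class name is hypothetical.
class CachedDeviceModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.device = torch.device("cpu")  # cached once, never updated by .to()

model = CachedDeviceModule()
if torch.cuda.is_available():
    model.to("cuda")                        # plain nn.Module usage outside Lightning
    print(next(model.parameters()).device)  # cuda:0 -- the parameters moved
    print(model.device)                     # cpu    -- the cached attribute is now stale
```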
for example here, we could simply override the .to
in LightningModule and no extra code in the Trainer is necessary
the code would be in one place and therefore easier to maintain
also, .to calls submodules too, so this approach would automatically take care of nested LightningModules!
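A rough sketch of that override idea, assuming a hypothetical `DeviceAwareModule` in place of LightningModule; the attribute name `_device` is illustrative, and the private helper `torch._C._nn._parse_to` is just one way to interpret the `.to(...)` arguments, not necessarily the final approach:

```python
import torch
import torch.nn as nn

# Hedged sketch: DeviceAwareModule stands in for LightningModule.
class DeviceAwareModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._device = torch.device("cpu")

    def to(self, *args, **kwargs):
        # Parse the .to(...) arguments the same way nn.Module does, remember
        # the target device, then delegate the actual move to the base class.
        device, *_ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self._device = device
        return super().to(*args, **kwargs)


class Net(DeviceAwareModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)


model = Net()
model.to(torch.device("cpu"))
print(model._device)  # kept in sync by the override: cpu
```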
oh wait... does nn.Module have a .device??
if so, i don't think we should overwrite it, no? i thought the weights had it, not the module.
i need to sleep lol.. so, i can think about it tomorrow... but sounds interesting :)
no it doesn't have it. see here:
pytorch/pytorch#7460
I guess if we do it then we would run into issues when the user starts to move their submodules to different devices by hand. but that would be a problem anyway :)
yes let's do it tomorrow :)
I think we could do this as a read-only property similar to what I did on metrics. But I also agree, this should be read-only and not be used for device transfers
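Something along those lines could look as follows; `ReadOnlyDeviceModule` is only an illustration, not the actual metrics or LightningModule implementation:

```python
import torch
import torch.nn as nn

# Sketch of the read-only variant discussed here: derive the device from the
# module's own parameters instead of caching it, so it cannot go out of sync.
class ReadOnlyDeviceModule(nn.Module):
    @property
    def device(self) -> torch.device:
        # Falls back to CPU for parameter-less modules; assumes all parameters
        # live on the same device, which manual model parallelism would break.
        try:
            return next(self.parameters()).device
        except StopIteration:
            return torch.device("cpu")


class Net(ReadOnlyDeviceModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)


model = Net()
print(model.device)  # cpu; reflects wherever the parameters actually are
```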
exactly like you did in metrics, that's exactly what I meant 👍 Nice
* Use trainer.strategy.root_device in favor of LightningModule.device in DeviceStatsMonitor

  Minor refactor to use the strategy's own `root_device` instead of the LightningModule's device property. Attempts at manual model parallelization by extending this plugin will face difficulties with the assumption that the LightningModule has all of its parameters on the same device. For those use cases, it is critical to remove the assumption that the module has a device property (a device attribute in general goes against PyTorch module's design principles: pytorch/pytorch#7460, #1790 (comment)).
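A hedged sketch of what that refactor looks like from the callback side, assuming a Lightning version where the Trainer exposes `trainer.strategy.root_device`; `RootDeviceLogger` and the print are purely illustrative, not the real DeviceStatsMonitor:

```python
import pytorch_lightning as pl

# Illustrative only: query the strategy for the device instead of relying on
# LightningModule.device, which assumes a single-device module.
class RootDeviceLogger(pl.Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        device = trainer.strategy.root_device  # owned by the strategy, not the module
        # A real DeviceStatsMonitor would gather device stats here instead of printing.
        print(f"collecting device stats on {device}")
```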
add self.device pointer to LightningModule