Refactor model for multi-device usage and easier disabling of masking #95
Conversation
Move the pos_encoding and band_encoding layers to the correct device in a way that allows Lightning to do multi-GPU properly. The reported loss is now synced or reduced/averaged across multiple devices too. Partially cherry-picked from 1a40f56
Co-Authored-By: SRM <soumya@developmentseed.org>
So that the masking can be turned off during prediction using `self.model.encoder.mask_ratio = 0`, where self is an instance of CLAYModule. The num_masked_patches integer value is now calculated on-the-fly by multiplying mask_ratio with num_patches.
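A minimal sketch of that on-the-fly computation (the names mask_ratio, num_patches, and num_masked_patches follow the PR text; the concrete values are made up for illustration and are not the actual model_clay.py code):

```python
# Sketch of the dynamic masking computation described above.
mask_ratio = 0.75  # fraction of patches to mask during training
num_patches = 256  # e.g. (image_size // patch_size) ** 2

# Recomputed from mask_ratio on each use rather than fixed in __init__,
# so setting mask_ratio = 0 later disables masking entirely.
num_masked_patches = int(mask_ratio * num_patches)  # 192 here; 0 if mask_ratio = 0
```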
src/model_clay.py
Outdated
# Move position & band encoding to the device | ||
self.pos_encoding = self.pos_encoding.to(patches.device) | ||
self.band_encoding = self.band_encoding.to(patches.device) |
@srmsoumya, I feel like we can also remove the `.to(device)` call here?
# Move position & band encoding to the device | |
self.pos_encoding = self.pos_encoding.to(patches.device) | |
self.band_encoding = self.band_encoding.to(patches.device) |
Ah, I think I found the core issue. The `pos_encoding` and `band_encoding` tensors are declared in the `__init__` method, before Lightning has managed to cast them to the correct device.
Lines 89 to 95 in ee74c91
# Fix the position & band embedding to sine & cosine functions
self.pos_encoding = posemb_sincos_2d(
    h=image_size // patch_size, w=image_size // patch_size, dim=pos_dim
)  # [L D/2]
self.band_encoding = posemb_sincos_1d(
    length=self.num_group_patches, dim=band_dim
)  # [G D/2]
Ideally, we'll need to define the device at this point. Let me see if there's a way to do that.
Edit: Yep, found the instructions - https://lightning.ai/docs/pytorch/2.1.0/starter/converting.html#remove-any-cuda-or-to-device-calls
Ok, went with using `self.register_buffer` at commit 25f83e6. Have set `persistent=False` since those pos_encoding and band_encoding tensors are fixed, and won't need to be saved to the `state_dict`.
Since the pos_encoding and band_encoding tensors are declared in the __init__ method, we'll need to register them so that they are moved to the correct device by Lightning during the forward call. See https://lightning.ai/docs/pytorch/2.1.0/starter/converting.html#remove-any-cuda-or-to-device-calls
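A minimal, self-contained sketch of that pattern (the random tensors stand in for the `posemb_sincos_2d`/`posemb_sincos_1d` outputs quoted above; the shapes and constructor arguments are assumptions, not the actual model_clay.py code):

```python
import torch
from torch import nn


class Encoder(nn.Module):
    """Sketch: fixed encodings registered as non-persistent buffers."""

    def __init__(self, num_patches: int = 64, num_group_patches: int = 6, dim: int = 128):
        super().__init__()
        # Stand-ins for the fixed sine/cosine encodings computed in __init__.
        pos_encoding = torch.randn(num_patches, dim // 2)         # [L D/2]
        band_encoding = torch.randn(num_group_patches, dim // 2)  # [G D/2]
        # register_buffer lets Lightning (and plain .to(device)) move these
        # tensors together with the module, so the manual
        # .to(patches.device) calls in forward() can be removed.
        # persistent=False keeps the fixed tensors out of the state_dict.
        self.register_buffer("pos_encoding", pos_encoding, persistent=False)
        self.register_buffer("band_encoding", band_encoding, persistent=False)
```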
Ok, gonna merge this patch in directly since others may benefit from the improved cuda/cpu device handling. Plus I'll need these patches for the embedding factory script.
Refactor model for multi-device usage and easier disabling of masking (#95)

* ♻️ Better handle pos and band encodings across multi-devices

  Move the pos_encoding and band_encoding layers to the correct device in a way that allows Lightning to do multi-GPU properly. The reported loss is now synced or reduced/averaged across multiple devices too. Partially cherry-picked from 1a40f56

  Co-Authored-By: SRM <soumya@developmentseed.org>

* ♻️ Compute num_masked_patches dynamically based on mask_ratio

  So that the masking can be turned off during prediction using `self.model.encoder.mask_ratio = 0`, where self is an instance of CLAYModule. The num_masked_patches integer value is now calculated on-the-fly by multiplying mask_ratio with num_patches.

* 🎨 Register pos_encoding and band_encoding properly on device

  Since the pos_encoding and band_encoding tensors are declared in the __init__ method, we'll need to register them so that they are moved to the correct device by Lightning during the forward call. See https://lightning.ai/docs/pytorch/2.1.0/starter/converting.html#remove-any-cuda-or-to-device-calls

---------

Co-authored-by: SRM <soumya@developmentseed.org>
A couple of small changes to model_clay.py after the initial implementation at #47:

1. Move the pos_encoding and band_encoding tensors to the correct device, fixing `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!` when training on GPU.
2. Compute num_masked_patches dynamically from mask_ratio, so that setting `self.model.encoder.mask_ratio = 0` would disable the masking (see the usage sketch below).

First one is cherry-picked from @srmsoumya's `ddp` branch. Second one will be helpful for the embedding generation later on.
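For illustration, a hypothetical prediction-time usage (the checkpoint path is a placeholder, and the exact loading call is an assumption based on CLAYModule being a LightningModule):

```python
from src.model_clay import CLAYModule  # module referenced in this PR

# Hypothetical checkpoint path, for illustration only.
module = CLAYModule.load_from_checkpoint("path/to/checkpoint.ckpt")
module.model.encoder.mask_ratio = 0  # num_masked_patches is now computed as 0
module.eval()  # run prediction with no patches masked
```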