
Refactor model for multi-device usage and easier disabling of masking #95

Merged
merged 3 commits into main from refactor-model-device-and-dynamic-masking on Dec 20, 2023

Conversation


@weiji14 weiji14 commented Dec 20, 2023

A couple of small changes to model_clay.py after the initial implementation in #47.

  • Handle layers across CPU/GPU devices better, to prevent errors like `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!`
  • Let the masking in the Masked Autoencoder be dynamic, so that setting `self.model.encoder.mask_ratio = 0` disables the masking (a minimal sketch of this is shown below).

The first change is cherry-picked from @srmsoumya's ddp branch. The second will be helpful for the embedding generation later on.
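To illustrate the dynamic masking idea, here is a minimal sketch of deriving the number of masked patches from mask_ratio at call time (names and shapes are simplified and hypothetical, not the exact model_clay.py code):

import torch

def random_masking(patches: torch.Tensor, mask_ratio: float):
    # patches: [batch, num_patches, dim]
    batch, num_patches, dim = patches.shape
    # num_masked_patches is computed on-the-fly from mask_ratio
    num_masked_patches = int(mask_ratio * num_patches)

    # Shuffle patch indices, then split into masked and unmasked sets
    noise = torch.rand(batch, num_patches, device=patches.device)
    shuffled = torch.argsort(noise, dim=1)
    masked_idx = shuffled[:, :num_masked_patches]
    unmasked_idx = shuffled[:, num_masked_patches:]

    # Keep only the unmasked patches for the encoder
    unmasked_patches = torch.gather(
        patches, dim=1, index=unmasked_idx.unsqueeze(-1).expand(-1, -1, dim)
    )
    return unmasked_patches, unmasked_idx, masked_idx

With mask_ratio = 0, num_masked_patches is 0, so every patch is kept and the masking is effectively disabled.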

weiji14 and others added 2 commits December 20, 2023 17:11
Move the pos_encoding and band_encoding layers to the correct device in a way that allows Lightning to do multi-GPU training properly. The reported loss is now synced (reduced/averaged) across multiple devices too. Partially cherry-picked from 1a40f56

Co-Authored-By: SRM <soumya@developmentseed.org>
So that the masking can be turned off during prediction using `self.model.encoder.mask_ratio = 0`, where self is an instance of CLAYModule. The num_masked_patches integer value is now calculated on-the-fly by multiplying mask_ratio with num_patches.
@weiji14 weiji14 added the model-architecture Pull requests about the neural network model architecture label Dec 20, 2023
@weiji14 weiji14 self-assigned this Dec 20, 2023
Comment on lines 282 to 284
# Move position & band encoding to the device
self.pos_encoding = self.pos_encoding.to(patches.device)
self.band_encoding = self.band_encoding.to(patches.device)
Contributor Author
@srmsoumya, I feel like we can also remove the .to(device) call here?

Suggested change
# Move position & band encoding to the device
self.pos_encoding = self.pos_encoding.to(patches.device)
self.band_encoding = self.band_encoding.to(patches.device)

Contributor Author

@weiji14 weiji14 Dec 20, 2023

Ah, I think I found the core issue. The pos_encoding and band_encoding tensors are declared in the __init__ method, before Lightning has managed to cast them to the correct device.

model/src/model_clay.py

Lines 89 to 95 in ee74c91

# Fix the position & band embedding to sine & cosine functions
self.pos_encoding = posemb_sincos_2d(
h=image_size // patch_size, w=image_size // patch_size, dim=pos_dim
) # [L D/2]
self.band_encoding = posemb_sincos_1d(
length=self.num_group_patches, dim=band_dim
) # [G D/2]

Ideally, we'll need to define the device at this point. Let me see if there's a way to do that.

Edit: Yep, found the instructions - https://lightning.ai/docs/pytorch/2.1.0/starter/converting.html#remove-any-cuda-or-to-device-calls

Contributor Author

Ok, went with using self.register_buffer at commit 25f83e6. Have set persistent=False since those pos_encoding and band_encoding tensors are fixed, and won't need to be saved to the state_dict.
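For reference, a minimal sketch of the register_buffer approach (simplified, hypothetical names; not the actual model_clay.py encoder):

import torch
from torch import nn

class EncoderSketch(nn.Module):
    def __init__(self, num_patches: int, pos_dim: int):
        super().__init__()
        # Placeholder for the fixed sine/cosine encoding (posemb_sincos_2d in the real code)
        pos_encoding = torch.zeros(num_patches, pos_dim)
        # persistent=False: the buffer is moved with the module across devices,
        # but is not saved to the state_dict since it is fixed/deterministic.
        self.register_buffer("pos_encoding", pos_encoding, persistent=False)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # No .to(patches.device) needed here; the buffer already lives on the
        # module's device, which Lightning manages.
        return patches + self.pos_encoding

This removes the need for the explicit .to(patches.device) calls in the forward pass.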

src/model_clay.py (outdated review thread, resolved)
Since the pos_encoding and band_encoding tensors are declared in the __init__ method, we'll need to register them so that they are moved to the correct device by Lightning during the forward call. See https://lightning.ai/docs/pytorch/2.1.0/starter/converting.html#remove-any-cuda-or-to-device-calls
@weiji14 weiji14 marked this pull request as ready for review December 20, 2023 08:32

weiji14 commented Dec 20, 2023

Ok, gonna merge this patch in directly since others may benefit from the improved CUDA/CPU device handling. Plus I'll need these patches for the embedding factory script.

@weiji14 weiji14 merged commit d05635d into main Dec 20, 2023
2 checks passed
@weiji14 weiji14 deleted the refactor-model-device-and-dynamic-masking branch December 20, 2023 08:50
brunosan pushed a commit that referenced this pull request Dec 27, 2023
…#95)

* ♻️ Better handle pos and band encodings across multi-devices

Move the pos_encoding and band_encoding layers to the correct device in a way that allows Lightning to do multi-GPU training properly. The reported loss is now synced (reduced/averaged) across multiple devices too. Partially cherry-picked from 1a40f56

Co-Authored-By: SRM <soumya@developmentseed.org>

* ♻️ Compute num_masked_patches dynamically based on mask_ratio

So that the masking can be turned off during prediction using `self.model.encoder.mask_ratio = 0`, where self is an instance of CLAYModule. The num_masked_patches integer value is now calculated on-the-fly by multiplying mask_ratio with num_patches.

* 🎨 Register pos_encoding and band_encoding properly on device

Since the pos_encoding and band_encoding tensors are declared in the __init__ method, we'll need to register them so that they are moved to the correct device by Lightning during the forward call. See https://lightning.ai/docs/pytorch/2.1.0/starter/converting.html#remove-any-cuda-or-to-device-calls

---------

Co-authored-by: SRM <soumya@developmentseed.org>