-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generate embeddings via prediction loop (#56)
* 🍻 Generate embeddings via prediction loop Implement the embedding generator in the LightningModule's predict_step. The embeddings are tensor arrays that are saved to a .npy file in the data/embeddings/ folder. Input data is retrieved from the predict_dataloader, which is currently using the validation datapipe rather than a dedicated datapipe. Have documented how to generate the embedding output file using LightningCLI on the main README.md file. Also added a unit test to ensure that saving and loading from an embedding_0.npy file works. * 🐛 Disable masking of patches on predict_step Previously, 75% of the patches, or 48 out of a total of 64 were masked out, leaving 16 patches plus 1 cls_token = 17 sequences. Disabling the mask gives 64 + 1 cls_token = 65 sequences. Moved some assert statements with a fixed sequence_length dim from the forward function to the training_step. Also updated the unit test to ensure output embeddings have a shape like (batch_size, 65, 768). * ♻️ Refactor LightningDataModule to not do random split on predict Refactoring the setup method in the LightningDataModule to not do a random split on the predict stage. I.e. just do the GeoTIFF to torch.Tensor conversion directly, followed by batching and collating. * ✅ Test predict stage in geotiffdatamodule Need to explicitly pass an argument to stage in the test_geotiffdatapipemodule unit test. Testing both the fit and predict stages. * 👔 Ensure that embeddings have no NaN values Make sure that the generated embeddings do not have NaN values in them. * 🗃️ Take mean of the embeddings along sequence_length dim Instead of saving embeddings of shape (1, 65, 768), save out embeddings of shape (1, 768) instead. Done by taking the mean along the sequence_length dim, except for the cls_token part (first index in the 65).
- Loading branch information
Showing
5 changed files
with
110 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters