
Initial Vision Transformer architecture with MAE decoder #37

Merged: 14 commits merged into main from model/init-vit on Nov 21, 2023

Conversation

@weiji14 weiji14 (Contributor) commented Nov 16, 2023

What I am changing

  • Initializing the neural network architecture, with a Vision Transformer (ViT) B/32 backbone and Masked Autoencoder (MAE) decoder

How I did it

  • ViT backbone/decoder architecture is from HuggingFace transformers
    • Loosely using the ViT B/32 model, but with 12 channels instead of 3.
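For illustration, a minimal sketch of how a ViT B/32-style MAE with 12 input channels can be set up with transformers; the image_size and mask_ratio values below are assumptions for demonstration, not necessarily what this PR uses:

```python
import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining

# ViT-B-like defaults (hidden_size=768, 12 layers, 12 heads) with a 32-pixel
# patch size and 12 input channels. image_size=256 and mask_ratio=0.75 are
# illustrative assumptions.
config = ViTMAEConfig(
    image_size=256,
    patch_size=32,
    num_channels=12,
    mask_ratio=0.75,
)
model = ViTMAEForPreTraining(config)

# Forward pass on a dummy mini-batch; the reconstruction loss on the masked
# patches is returned directly by the model.
outputs = model(pixel_values=torch.randn(2, 12, 256, 256))
print(outputs.loss)
```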

TODO:

  • Install transformers dependency
  • Initialize model architecture backbone/decoder layers
  • Set up training_step and forward pass
  • Add unit tests
  • Document model architecture in src/README.md

How you can test it

  • Run `python trainer.py fit --trainer.max_epochs=10` locally
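For context, trainer.py is assumed here to be a LightningCLI entry point roughly like the sketch below; the actual file in the repository may differ.

```python
# trainer.py (sketch): a LightningCLI wrapper so that
# `python trainer.py fit --trainer.max_epochs=10` works from the command line.
from lightning.pytorch.cli import LightningCLI


def main():
    # The LightningModule and DataModule are assumed to be supplied via
    # command-line arguments or a config file.
    LightningCLI()


if __name__ == "__main__":
    main()
```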

Related Issues

Working towards #3

References:

  • He, K., Chen, X., Xie, S., Li, Y., Dollar, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 15979–15988. https://doi.org/10.1109/CVPR52688.2022.01553

Somehow using the `--with-cuda=11.8` flag in conda-lock didn't work as expected to get the CUDA-built PyTorch instead of the CPU version. Temporarily downgrading from PyTorch 2.1 to 2.0 and CUDA 11.8 to 11.2, to make it possible to install torchvision=0.15.2 from conda-forge later.
Initializing the neural network architecture layers, specifically a Vision Transformer (ViT) B/32 backbone and a Masked Autoencoder (MAE) decoder. Using Lightly for the MAE setup, with the ViT backbone from torchvision. Setup is mostly adapted from https://github.com/lightly-ai/lightly/blob/v1.4.21/examples/pytorch_lightning/mae.py
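Roughly, that setup looks like the sketch below; the lightly calls are assumed from the v1.4.21 example linked above, and the decoder plus MSE reconstruction loss pieces are omitted here:

```python
import torchvision
from lightly.models.modules import masked_autoencoder

# ViT B/32 backbone from torchvision, wrapped into lightly's MAE backbone.
vit = torchvision.models.vit_b_32(weights=None)
backbone = masked_autoencoder.MAEBackbone.from_vit(vit)
# A masked_autoencoder.MAEDecoder and an MSE reconstruction loss complete the
# setup, following the linked example.
```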
@weiji14 weiji14 added the model-architecture label Nov 16, 2023
@weiji14 weiji14 added this to the v0 Release milestone Nov 16, 2023
@weiji14 weiji14 self-assigned this Nov 16, 2023
Changing from lightly/torchvision's ViTMAE implementation to HuggingFace transformers' ViTMAE. This allows us to configure the number of input channels to a number other than 3 (e.g. 12). However, transformers' ViTMAE is an all-in-one class rather than an Encoder/Decoder split (though there's a way to access either once the class is instantiated). Allowed for configuring the masking_ratio instead of the decoder_dim size, and removed the MSE loss because it is already implemented in the ViTMAE class.
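For example, the encoder and decoder can still be reached on the instantiated model (a sketch, reusing the illustrative config values from above; mask_ratio defaults to 0.75):

```python
from transformers import ViTMAEConfig, ViTMAEForPreTraining

config = ViTMAEConfig(num_channels=12, image_size=256, patch_size=32, mask_ratio=0.75)
model = ViTMAEForPreTraining(config)

encoder = model.vit      # ViTMAEModel, the encoder
decoder = model.decoder  # ViTMAEDecoder
```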
Run input images through the encoder and decoder, and compute the pixel reconstruction loss from training the Masked Autoencoder.
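A sketch of such a training step in the LightningModule (named MAELitModule before the later rename); the self.model attribute, learning rate, and logged metric name are assumptions:

```python
import lightning.pytorch as pl
import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining


class MAELitModule(pl.LightningModule):
    def __init__(self, lr: float = 1e-4):
        super().__init__()
        # Illustrative config values, as in the sketches above.
        self.model = ViTMAEForPreTraining(
            ViTMAEConfig(num_channels=12, image_size=256, patch_size=32)
        )
        self.lr = lr

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        # Encoder + decoder forward pass; the pixel reconstruction (MSE) loss on
        # the masked patches is computed inside ViTMAEForPreTraining.
        outputs = self.model(pixel_values=batch)
        self.log("train/loss", outputs.loss)
        return outputs.loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
```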
Ensure that running one training step on a mini-batch works. Created a random torch Dataset that generates tensors of shape (12, 256, 256) until there is real data to train on.
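A sketch of such a placeholder dataset (the class name and length are illustrative):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    """Yields random tensors shaped like the expected (12, 256, 256) datacube chips."""

    def __init__(self, length: int = 1024):
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.randn(12, 256, 256)


dataloader = DataLoader(RandomDataset(), batch_size=32)
```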
No need to pin to CUDA 11.2 since not using torchvision anymore. Patches 06535cd
The datacube has 13 channels, namely 10 from Sentinel-2's 10m and 20m resolution bands, 2 from Sentinel-1's VV and VH, and 1 from the Copernicus DEM.
Use a variable self.B instead of hardcoding 32 as the batch_size in the assert statements checking the tensor shape, so that the last mini-batch with a size less than 32 can be seen by the model.
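Illustratively (a sketch; the helper name is only for demonstration, and the expected shape follows the random dataset above):

```python
import torch


def assert_batch_shape(batch: torch.Tensor) -> int:
    # Derive the batch size from the mini-batch itself instead of hardcoding 32,
    # so the last, smaller mini-batch still passes the shape check.
    B = batch.shape[0]
    assert batch.shape == (B, 12, 256, 256)
    return B
```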
@weiji14 weiji14 marked this pull request as ready for review November 21, 2023 05:16
@srmsoumya srmsoumya (Collaborator) left a comment

The implementation looks good to me & we have enough options to modify in the MAE & ViT backbone.
Let us use this model for the current sprint; next week we need to add options to:

  1. Add embeddings for time, lat/lon, channels & position
  2. Implement different masking strategies, such as random masking and grouped channel/time masking
  3. Add support for different backbones such as Swin or FlexiViT

Rename MAELitModule to ViTLitModule, and model.py to model_vit.py, since we might be trying out different neural network model architectures later.
@weiji14 weiji14 (Contributor, Author) commented Nov 21, 2023

Thanks @srm, I see you're starting the spatiotemporal embedding work on GeoViT at #47, and we can work on the masking strategy and different model backbones later too. I've renamed model.py to model_vit.py in case we want to have other model_*.py files. Will merge this into the main branch now.

@weiji14 weiji14 merged commit 17f4698 into main Nov 21, 2023
2 checks passed
@weiji14 weiji14 deleted the model/init-vit branch November 21, 2023 22:46