
Commit

update
Vishu26 committed Oct 28, 2023
1 parent 5a15cfc commit f94ea09
Showing 3 changed files with 333 additions and 31 deletions.
13 changes: 12 additions & 1 deletion README.md
@@ -17,6 +17,18 @@ An example of task 3 is shown below:

## 👨‍💻 Getting Started

#### Setting up
1. Clone this repository:
```bash
git clone https://github.com/mvrl/BirdSAT.git
```
2. Clone the Remote-Sensing-RVSA repository inside BirdSAT:
```bash
cd BirdSAT
git clone https://github.com/ViTAE-Transformer/Remote-Sensing-RVSA.git
```
3. **Append** the CVMMAE code in `utils_model/CVMMAE.py` to the file `Remote-Sensing-RVSA/MAEPretrain_SceneClassification/models_mae_vitae.py`.
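
One minimal way to do this from the repository root is sketched below (assuming the paths above; `cat` with `>>` simply appends the file's contents):
```bash
# Append the CVMMAE code to the end of the RVSA model definitions file.
cat utils_model/CVMMAE.py >> Remote-Sensing-RVSA/MAEPretrain_SceneClassification/models_mae_vitae.py
```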

#### Installing Required Packages
There are two options to set up your environment to be able to run all the functions in the repository:
1. Using the Dockerfile provided in the repository to create a Docker image with all required packages:
@@ -29,7 +41,6 @@ There are two options to set up your environment to be able to run all the functions in the repository:
conda activate birdsat && \
pip install -r requirements.txt
```

Additionally, we host a pre-built Docker image on Docker Hub under the tag `srikumar26/birdsat:latest` for use.
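
For example, pulling and starting the image could look like the sketch below (the run flags and mount path are illustrative assumptions, not a verified command from this repository):
```bash
docker pull srikumar26/birdsat:latest
# Start an interactive container with GPU access and the current directory mounted.
docker run --gpus all -it --rm -v "$(pwd)":/workspace srikumar26/birdsat:latest
```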

## 🔥 Training Models
70 changes: 40 additions & 30 deletions models.py
@@ -1,4 +1,4 @@
from .MAEPretrain_SceneClassification.models_mae_vitae import (
from Remote_Sensing_RVSA.MAEPretrain_SceneClassification.models_mae_vitae import (
mae_vitae_base_patch16_dec512d8b,
MaskedAutoencoderViTAE,
CrossViewMaskedAutoencoder,
@@ -69,9 +69,7 @@ def forward_features(self, img_ground):
embeddings, *_ = self.model.forward_encoder(img_ground, 0)
return embeddings

def forward_finetune(self, img_ground, label):
img_ground, label = self.mixup(img_ground, label)
img_ground, label = self.cutmix(img_ground, label)
def forward_finetune(self, img_ground):
embeddings, *_ = self.model.forward_encoder(img_ground, 0)
out = self.classify(embeddings[:, 0])
return out
@@ -81,9 +79,12 @@ def shared_step_pretrain(self, batch, batch_idx):
loss_recon = self(img_ground)
return loss_recon

def shared_step_finetune(self, batch, batch_idx):
def shared_step_finetune(self, batch, batch_idx, eval=False):
img_ground, labels = batch[0], batch[1]
preds = self.forward_finetune(img_ground, labels)
if not eval:
img_ground, labels = self.mixup(img_ground, labels)
img_ground, labels = self.cutmix(img_ground, labels)
preds = self.forward_finetune(img_ground)
loss = self.criterion(preds, labels)
acc = self.acc(preds, labels)
return loss, acc
@@ -96,7 +97,7 @@ def training_step(self, batch, batch_idx):
)
return {"loss": loss_recon}
else:
loss, acc = self.shared_step_finetune(batch, batch_idx)
loss, acc = self.shared_step_finetune(batch, batch_idx, eval=False)
self.log("train_acc", acc, on_epoch=True, prog_bar=True)
self.log("train_loss", loss, prog_bar=True, on_epoch=True)
return {"loss": loss, "acc": acc}
@@ -109,7 +110,7 @@ def validation_step(self, batch, batch_idx):
)
return {"loss": loss_recon}
else:
loss, acc = self.shared_step_finetune(batch, batch_idx)
loss, acc = self.shared_step_finetune(batch, batch_idx, eval=True)
self.log("val_acc", acc, prog_bar=True, on_epoch=True)
self.log("val_loss", loss, prog_bar=True, on_epoch=True)
return {"loss": loss, "acc": acc}
@@ -169,7 +170,7 @@ def __init__(self, trainset, validset, queueset):
self.sat_encoder = mae_vitae_base_patch16_dec512d8b()
self.sat_encoder.load_state_dict(
torch.load(
"/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth"
"pretrained_models/vitae-b-checkpoint-1599-transform-no-average.pth"
)["model"]
)
self.sat_encoder.requires_grad_(False)
@@ -263,8 +264,11 @@ def shared_step_pretrain(self, batch):

return loss, loss_cont, loss_geo

def shared_step_finetune(self, batch, batch_idx):
def shared_step_finetune(self, batch, batch_idx, eval=False):
ground_img, labels = batch[0], batch[1]
if not eval:
ground_img, labels = self.mixup(ground_img, labels)
ground_img, labels = self.cutmix(ground_img, labels)
preds = self.forward_finetune(ground_img)
loss = self.criterion(preds, labels)
acc = self.acc(preds, labels)
@@ -307,7 +311,7 @@ def training_step(self, batch, batch_idx):
)
self.log("train_geo_loss", train_geo_loss, prog_bar=True, sync_dist=True)
else:
loss, acc = self.shared_step_finetune(batch, batch_idx)
loss, acc = self.shared_step_finetune(batch, batch_idx, eval=False)
self.log("train_acc", acc, on_epoch=True, prog_bar=True)
self.log("train_loss", loss, prog_bar=True, on_epoch=True)
return {"loss": loss, "acc": acc}
@@ -329,7 +333,7 @@ def validation_step(self, batch, batch_idx):
)
self.log("val_geo_loss", val_geo_loss, sync_dist=True, prog_bar=True)
else:
loss, acc = self.shared_step_finetune(batch, batch_idx)
loss, acc = self.shared_step_finetune(batch, batch_idx, eval=True)
self.log("val_acc", acc, prog_bar=True, on_epoch=True)
self.log("val_loss", loss, prog_bar=True, on_epoch=True)
return {"loss": loss, "acc": acc}
@@ -508,9 +512,7 @@ def forward_features(self, img_ground, img_overhead, geoloc=None, date=None):
)
return norm_ground_features, norm_overhead_features

def forward_finetune(self, img_ground, label, geoloc=None, date=None):
img_ground, label = self.mixup(img_ground, label)
img_ground, label = self.cutmix(img_ground, label)
def forward_finetune(self, img_ground, geoloc=None, date=None):
ground_embeddings, *_ = self.ground_encoder.forward_encoder(img_ground, 0)
if geoloc is None or date is None:
norm_ground_features = F.normalize(ground_embeddings[:, 0], dim=-1)
@@ -541,13 +543,19 @@ def shared_step_pretrain(self, batch, batch_idx):
loss = 0.3 * loss_clip + loss_recon
return loss, loss_clip, loss_recon

def shared_step_finetune(self, batch, batch_idx):
def shared_step_finetune(self, batch, batch_idx, eval=False):
if cfg.pretrain.train.mode == "no_metadata":
img_ground, labels = batch[0], batch[1]
preds = self.forward_finetune(img_ground, labels)
if not eval:
img_ground, labels = self.mixup(img_ground, labels)
img_ground, labels = self.cutmix(img_ground, labels)
preds = self.forward_finetune(img_ground)
else:
img_ground, labels, geoloc, date = batch[0], batch[1], batch[2], batch[3]
preds = self.forward_finetune(img_ground, labels, geoloc, date)
if not eval:
img_ground, labels = self.mixup(img_ground, labels)
img_ground, labels = self.cutmix(img_ground, labels)
preds = self.forward_finetune(img_ground, geoloc, date)
loss = self.criterion(preds, labels)
acc = self.acc(preds, labels)
return loss, acc
@@ -572,7 +580,7 @@ def training_step(self, batch, batch_idx):
)
return {"loss": loss, "loss_clip": loss_clip, "loss_recon": loss_recon}
else:
loss, acc = self.shared_step_finetune(batch, batch_idx)
loss, acc = self.shared_step_finetune(batch, batch_idx, eval=False)
self.log("train_acc", acc, on_epoch=True, prog_bar=True, sync_dist=True)
self.log("train_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
return {"loss": loss, "acc": acc}
@@ -593,7 +601,7 @@ def validation_step(self, batch, batch_idx):
)
return {"loss": loss, "loss_clip": loss_clip, "loss_recon": loss_recon}
else:
loss, acc = self.shared_step_finetune(batch, batch_idx)
loss, acc = self.shared_step_finetune(batch, batch_idx, eval=True)
self.log("val_acc", acc, prog_bar=True, on_epoch=True, sync_dist=True)
self.log("val_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
return {"loss": loss, "acc": acc}
@@ -721,9 +729,7 @@ def forward_features(self, img_ground, img_overhead, geoloc=None, date=None):
)
return norm_embeddings, self.match(norm_embeddings)

def forward_finetune(self, img_ground, img_overhead, label, geoloc=None, date=None):
img_ground, label = self.mixup(img_ground, label)
img_ground, label = self.cutmix(img_ground, label)
def forward_finetune(self, img_ground, img_overhead, geoloc=None, date=None):
embeddings, *_ = self.model.forward_encoder(img_ground, img_overhead, 0)
if cfg.pretrain.train.mode == "no_metadata":
norm_embeddings = F.normalize(embeddings[:, 0], dim=-1)
@@ -754,10 +760,13 @@ def shared_step_pretrain(self, batch, batch_idx):
loss = loss_matching + loss_recon
return loss, loss_matching, loss_recon

def shared_step_finetune(self, batch, batch_idx):
def shared_step_finetune(self, batch, batch_idx, eval=False):
if cfg.pretrain.train.mode == "no_metadata":
img_ground, img_overhead, labels = batch[0], batch[1], batch[2]
preds = self.forward_finetune(img_ground, img_overhead, labels)
if not eval:
img_ground, labels = self.mixup(img_ground, labels)
img_ground, labels = self.cutmix(img_ground, labels)
preds = self.forward_finetune(img_ground, img_overhead)
else:
img_ground, img_overhead, labels, geoloc, date = (
batch[0],
@@ -766,9 +775,10 @@ def shared_step_finetune(self, batch, batch_idx):
batch[3],
batch[4],
)
preds = self.forward_finetune(
img_ground, img_overhead, labels, geoloc, date
)
if not eval:
img_ground, labels = self.mixup(img_ground, labels)
img_ground, labels = self.cutmix(img_ground, labels)
preds = self.forward_finetune(img_ground, img_overhead, geoloc, date)
loss = self.criterion(preds, labels)
acc = self.acc(preds, labels)
return loss, acc
@@ -793,7 +803,7 @@ def training_step(self, batch, batch_idx):
)
return {"loss": loss, "loss_clip": loss_clip, "loss_recon": loss_recon}
else:
loss, acc = self.shared_step_finetune(batch, batch_idx)
loss, acc = self.shared_step_finetune(batch, batch_idx, eval=False)
self.log("train_acc", acc, on_epoch=True, prog_bar=True, sync_dist=True)
self.log("train_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
return {"loss": loss, "acc": acc}
@@ -814,7 +824,7 @@ def validation_step(self, batch, batch_idx):
)
return {"loss": loss, "loss_clip": loss_clip, "loss_recon": loss_recon}
else:
loss, acc = self.shared_step_finetune(batch, batch_idx)
loss, acc = self.shared_step_finetune(batch, batch_idx, eval=True)
self.log("val_acc", acc, prog_bar=True, on_epoch=True, sync_dist=True)
self.log("val_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
return {"loss": loss, "acc": acc}
