Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Vishu26 committed Oct 28, 2023
1 parent b5867fd commit 5a15cfc
Show file tree
Hide file tree
Showing 7 changed files with 343 additions and 105 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# BirdSAT: Cross-View Contrastive Masked Autoencoders for Bird Species Classification and Mapping

## 🦢 Dataset Released: Cross-View iNAT Birds 2021
This cross-view bird species dataset consists
of paired ground-level bird images and satellite images, along with meta-information associated with the iNaturalist-2021 dataset.
This cross-view bird species dataset consists of paired ground-level bird images and satellite images, along with meta-information associated with the iNaturalist-2021 dataset.

![CiNAT-Birds-2021](imgs/data.png)

Expand Down
10 changes: 6 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
cfg.pretrain.train.weight_decay = 1e-2
cfg.pretrain.train.accumulate_grad_batches = 1
cfg.pretrain.train.warmup_epochs = 40
cfg.pretrain.train.model_type = "CVEMAE" # one of 'CVEMAE', 'CVMMAE', 'MAE'
cfg.pretrain.train.model_type = "CVEMAE" # one of 'CVEMAE', 'CVMMAE', 'MAE', 'MOCOGEO'
cfg.pretrain.train.expt_name = "CVEMAE_v1"


Expand Down Expand Up @@ -70,7 +70,7 @@
cfg.finetune.train.accumulate_grad_batches = 1
cfg.finetune.train.warmup_epochs = 40
cfg.finetune.train.label_smoothing = 0.05
cfg.finetune.train.model_type = "CVEMAE" # one of 'CVEMAE', 'CVMMAE', 'MAE'
cfg.finetune.train.model_type = "CVEMAE" # one of 'CVEMAE', 'CVMMAE', 'MAE', 'MOCOGEO'
cfg.finetune.train.expt_name = "CVEMAE_finetune_v1"
cfg.finetune.train.dataset = "CUB" # one of 'iNAT', 'CUB', 'NABirds'
cfg.finetune.train.linear_probe = False
Expand All @@ -79,9 +79,11 @@

cfg.retrieval = edict()
cfg.retrieval.enabled = False
cfg.retrieval.model_type = "CVEMAE" # one of 'CVEMAE', 'CVMMAE'
cfg.retrieval.model_type = "CVEMAE" # one of 'CVEMAE', 'CVMMAE', 'MOCOGEO'
cfg.retrieval.mode = "full_metadata" # one of 'no_metadata', 'full_metadata'
cfg.retrieval.batch_size = 77
cfg.retrieval.topk = 10
cfg.retrieval.hierarchical_filter = 50
cfg.retrieval.batch_size = cfg.retrieval.hierarchical_filter
cfg.retrieval.devices = 1
cfg.retrieval.num_workers = 12
cfg.retrieval.ckpt = "checkpoints/CVEMAE_v1-epoch=99-val_loss=0.00.ckpt"
Expand Down
68 changes: 0 additions & 68 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,71 +294,3 @@ def __getitem__(self, idx):
img = Image.open(img_path)
img = self.transform(img)
return img, torch.tensor(label)


class HeirarchicalRet(Dataset):
    """Dataset of paired ground-level and overhead images for a list of ids.

    ``dataset["images"]`` is converted to an array and re-ordered by the
    integer values found in the second column of ``label``. ``ids`` then
    selects which entries of that re-ordered array are served, and also
    names the overhead image file for each sample.
    """

    # Per-channel normalisation statistics shared by both views.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, dataset, label, ids):
        """Store the image records and build one transform per view.

        Args:
            dataset: Mapping with an ``"images"`` entry holding the records.
            label: Object exposing ``.iloc``; its second column supplies the
                integer indices used to select and order the records.
            ids: Sequence of indices into the selected records.
        """
        self.dataset = dataset
        all_records = np.array(self.dataset["images"])
        self.idx = np.array(label.iloc[:, 1]).astype(int)
        self.images = all_records[self.idx]
        self.ids = ids
        self.transform_ground = self._make_transform(cfg.pretrain.ground.img_size)
        self.transform_overhead = self._make_transform(
            cfg.pretrain.overhead.img_size
        )

    @classmethod
    def _make_transform(cls, img_size):
        """Return a resize -> tensor -> normalise pipeline for ``img_size``."""
        steps = [
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=cls._NORM_MEAN, std=cls._NORM_STD),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        """Return the number of ids served by this dataset."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Load the ground/overhead image pair (and metadata) for ``idx``.

        Returns:
            ``(img_ground, img_overhead)`` when the configured mode is
            ``"no_metadata"``; otherwise
            ``(img_ground, img_overhead, location, date_encode)`` where
            ``location`` holds sin/cos encodings of latitude and longitude
            and ``date_encode`` holds sin/cos encodings of month and day,
            both as float tensors.
        """
        sample_id = self.ids[idx]
        record = self.images[sample_id]

        img_path = record["file_name"]
        lat = record["latitude"]
        lon = record["longitude"]

        # Only the calendar date is used; anything after the first space
        # in the stored string is discarded.
        date_str = record["date"].split(" ")[0]
        parsed = datetime.strptime(date_str, "%Y-%m-%d").date()
        month = parsed.month
        day = parsed.day

        # Cyclic encoding so that the end of a period sits next to its start.
        date_encode = torch.tensor(
            [
                np.sin(2 * np.pi * month / 12),
                np.cos(2 * np.pi * month / 12),
                np.sin(2 * np.pi * day / 31),
                np.cos(2 * np.pi * day / 31),
            ]
        )

        ground_file = os.path.join("data", img_path)
        img_ground = self.transform_ground(Image.open(ground_file))

        overhead_file = f"data/val_overhead/images_sentinel/{sample_id}.jpeg"
        img_overhead = self.transform_overhead(Image.open(overhead_file))

        # NOTE(review): this branches on the *pretrain* mode flag; confirm that
        # is the intended switch for this dataset.
        if cfg.pretrain.train.mode == "no_metadata":
            return img_ground, img_overhead

        location = torch.tensor(
            [
                np.sin(np.pi * lat / 90),
                np.cos(np.pi * lat / 90),
                np.sin(np.pi * lon / 180),
                np.cos(np.pi * lon / 180),
            ]
        ).float()
        return img_ground, img_overhead, location, date_encode.float()
13 changes: 12 additions & 1 deletion finetune.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import pytorch_lightning as pl
from datasets import CrossViewiNATBirdsFineTune
from models import MAE, CVEMAEMeta, CVMMAEMeta
from models import MAE, CVEMAEMeta, CVMMAEMeta, MoCoGeo
from torch.utils.data import random_split
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
Expand Down Expand Up @@ -41,18 +41,29 @@ def finetune():
train_dataset=train_dataset,
val_dataset=val_dataset,
)
model.setup_finetune()
elif cfg.finetune.train.model_type == "MOCOGEO":
model = MoCoGeo.load_from_checkpoint(
cfg.finetune.train.ckpt,
train_dataset=train_dataset,
val_dataset=val_dataset,
queue_dataset=None,
)
model.setup_finetune()
elif cfg.finetune.train.model_type == "CVEMAE":
model = CVEMAEMeta.load_from_checkpoint(
cfg.finetune.train.ckpt,
train_dataset=train_dataset,
val_dataset=val_dataset,
)
model.setup_finetune()
elif cfg.finetune.train.model_type == "CVMMAE":
model = CVMMAEMeta.load_from_checkpoint(
cfg.finetune.train.ckpt,
train_dataset=train_dataset,
val_dataset=val_dataset,
)
model.setup_finetune()

trainer = pl.Trainer(
accelerator="gpu",
Expand Down
Loading

0 comments on commit 5a15cfc

Please sign in to comment.