Readme #11

Triggered via push on October 27, 2023, 16:58
Status: Success
Total duration: 24s
Billable time: 1m
linting.yml
on: push

Run linters (13s)
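The workflow file itself is not shown on this page. Below is a minimal sketch of what a linting.yml consistent with this run might contain; only the trigger (on: push), the job name (Run linters), the black-style reformat diffs under Annotations, and the action versions named in the warning at the bottom are confirmed by the page. The workflow name, step layout, and the black invocation are assumptions.

# Hypothetical sketch of linting.yml -- not the repository's actual file.
name: Lint                              # assumed; the run title "Readme" is the commit message
on: push

jobs:
  lint:
    name: Run linters
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2       # named in the node12 warning below
      - uses: actions/setup-python@v1   # named in the node12 warning below
      - run: pip install black
      - run: black --check --diff .     # assumed; would emit reformat diffs like those under Annotations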

Annotations

11 errors and 1 warning
/home/runner/work/BirdSAT/BirdSAT/Downstream/CUBDownstream.py#L1
-from .MAEPretrain_SceneClassification.models_mae_vitae import mae_vitae_base_patch16_dec512d8b, MaskedAutoencoderViTAE
-import torch
+from .MAEPretrain_SceneClassification.models_mae_vitae import (
+    mae_vitae_base_patch16_dec512d8b,
+    MaskedAutoencoderViTAE,
+)
+import torch
 import torch.nn as nn
 import pytorch_lightning as pl
 from pytorch_lightning import LightningModule
 from torch.utils.data import DataLoader, Dataset, random_split, ConcatDataset
 from torchvision import transforms, datasets
/home/runner/work/BirdSAT/BirdSAT/Downstream/CUBDownstream.py#L25
 from timm.data import Mixup
 from timm.data import create_transform
 from timm.loss import SoftTargetCrossEntropy
 from timm.utils import accuracy
+
 class MaeBirds(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
         self.sat_encoder = mae_vitae_base_patch16_dec512d8b()
-        self.sat_encoder.load_state_dict(torch.load('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth')['model'])
+        self.sat_encoder.load_state_dict(
+            torch.load(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth"
+            )["model"]
+        )
         self.sat_encoder.requires_grad_(False)
-        self.ground_encoder = MaskedAutoencoderViTAE(img_size=384, patch_size=32, in_chans=3,
-                embed_dim=768, depth=12, num_heads=12,
-                decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
-                mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_pix_loss=False, kernel=3, mlp_hidden_dim=None)
+        self.ground_encoder = MaskedAutoencoderViTAE(
+            img_size=384,
+            patch_size=32,
+            in_chans=3,
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            decoder_embed_dim=512,
+            decoder_depth=8,
+            decoder_num_heads=16,
+            mlp_ratio=4.0,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_pix_loss=False,
+            kernel=3,
+            mlp_hidden_dim=None,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 77)
-        self.num_workers = kwargs.get('num_workers', 16)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 77)
+        self.num_workers = kwargs.get("num_workers", 16)
+        self.lr = kwargs.get("lr", 0.02)
         self.geo_encode = nn.Linear(4, 768)
         self.date_encode = nn.Linear(4, 768)

     def forward(self, img_ground, val=False):
         if not val:
-            ground_embeddings, *_ = self.ground_encoder.forward_encoder(img_ground, 0.3055)
+            ground_embeddings, *_ = self.ground_encoder.forward_encoder(
+                img_ground, 0.3055
+            )
             return F.normalize(ground_embeddings[:, 0], dim=-1)
         else:
             ground_embeddings, *_ = self.ground_encoder.forward_encoder(img_ground, 0)
             return F.normalize(ground_embeddings[:, 0], dim=-1)
+
 class MaeBirdsDownstream(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
-        self.model = MaeBirds.load_from_checkpoint('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveGeoDateMAEv5-epoch=28-val_loss=1.53.ckpt', train_dataset=train_dataset, val_dataset=val_dataset)
+        self.model = MaeBirds.load_from_checkpoint(
+            "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveGeoDateMAEv5-epoch=28-val_loss=1.53.ckpt",
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 32)
-        self.num_workers = kwargs.get('num_workers', 16)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 32)
+        self.num_workers = kwargs.get("num_workers", 16)
+        self.lr = kwargs.get("lr", 0.02)
         self.classify = nn.Linear(768, 1486)
-        #self.criterion = SoftTargetCrossEntropy()
+        # self.criterion = SoftTargetCrossEntropy()
         self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
-        self.acc = Accuracy(task='multiclass', num_classes=1486)
+        self.acc = Accuracy(task="multiclass", num_classes=1486)
         self.mixup_fn = Mixup(
-            mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
-            prob=1.0, switch_prob=0.5, mode='batch',
-            label_smoothing=0.1, num_classes=1486)
+            mixup_alpha=0.8,
+            cutmix_alpha=1.0,
+            cutmix_minmax=None,
+            prob=1.0,
+            switch_prob=0.5,
+            mode="batch",
+            label_smoothing=0.1,
+            num_classes=1486,
+        )

     def forward(self, img_ground, val):
         return self.model(img_ground, val)
+
 class CUBDownstream(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
-        self.model = MaeBirdsDownstream.load_from_checkpoint('/storage1/fs1/jacobsn/Active/user_s.sastry/checkpoints/ContrastiveDownstreamGeoMAEv10-epoch=05-val_loss=1.71.ckpt', train_dataset=train_dataset, val_dataset=val_dataset)
+        self.model = MaeBirdsDownstream.load_from_checkpoint(
+            "/storage1/fs1/jacobsn/Active/user_s.sastry/checkpoints/ContrastiveDownstreamGeoMAEv10-epoch=05-val_loss=1.71.ckpt",
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 32)
-        self.num_workers = kwargs.get('num_workers', 16)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 32)
+        self.num_workers = kwargs.get("num_workers", 16)
+        self.lr = kwargs.get("lr", 0.02)
         self.classify = nn.Linear(768, 200)
         self.criterion = nn.CrossEntropyLoss()
-        self.acc = Accuracy(task='multiclass', num_classes=200)
+        self.acc = Accuracy(task="multiclass", num_classes=200)

     def forward(self, img_ground, val):
         return self.classify(self.model(img_ground, val))

     def shared_step(self, batch, batch_idx, val=False):
/home/runner/work/BirdSAT/BirdSAT/Downstream/CUBDownstream.py#L97
         acc = self.acc(preds, labels)
         return loss, acc

     def training_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch, batch_idx)
-        self.log('train_acc', acc, on_epoch=True, prog_bar=True)
-        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
+        self.log("train_acc", acc, on_epoch=True, prog_bar=True)
+        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
         return {"loss": loss, "acc": acc}

     def validation_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch, batch_idx, True)
-        self.log('val_acc', acc, prog_bar=True, on_epoch=True)
-        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
-        return {"loss": loss, "acc":acc}
-
+        self.log("val_acc", acc, prog_bar=True, on_epoch=True)
+        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
+        return {"loss": loss, "acc": acc}
+
     def predict_step(self, batch, batch_idx):
         acc = self.shared_step(batch, batch_idx)
         return acc

     def train_dataloader(self):
-        return DataLoader(self.train_dataset,
-                shuffle=True,
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-                persistent_workers=False,
-                pin_memory=True)
+        return DataLoader(
+            self.train_dataset,
+            shuffle=True,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=False,
+            pin_memory=True,
+        )

     def val_dataloader(self):
-        return DataLoader(self.val_dataset,
-                shuffle=False,
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-                persistent_workers=True,
-                pin_memory=True)
+        return DataLoader(
+            self.val_dataset,
+            shuffle=False,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=True,
+            pin_memory=True,
+        )

     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(self.parameters(), lr=2e-4, weight_decay=0.02)
         scheduler = CosineAnnealingWarmRestarts(optimizer, 5)
         return [optimizer], [scheduler]
+
 class CUBBirds(Dataset):
     def __init__(self, path, val=False):
         self.path = path
         self.images = np.loadtxt(os.path.join(self.path, "train_test_split.txt"))
         if not val:
             self.images = self.images[self.images[:, 1] == 1]
         else:
             self.images = self.images[self.images[:, 1] == 0]
-        self.img_paths = np.genfromtxt(os.path.join(self.path, 'images.txt'),dtype='str')
+        self.img_paths = np.genfromtxt(
+            os.path.join(self.path, "images.txt"), dtype="str"
+        )
         if not val:
-            self.transform = transforms.Compose([
-                transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
-                transforms.RandAugment(12, 12, interpolation=transforms.InterpolationMode.BICUBIC),
-                transforms.TrivialAugmentWide(num_magnitude_bins=50, interpolation=transforms.InterpolationMode.BICUBIC),
-                transforms.AugMix(9, 9, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.RandomHorizontalFlip(0.5),
-                transforms.RandomVerticalFlip(0.5),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
+            self.transform = transforms.Compose(
+                [
+                    transforms.Resize(
+                        (384, 384), interpolation=transforms.InterpolationMode.BICUBIC
+                    ),
+                    transforms.RandAugment(
+                        12, 12, interpolation=transforms.InterpolationMode.BICUBIC
+                    ),
+                    transforms.TrivialAugmentWide(
+                        num_magnitude_bins=50,
+                        interpolation=transforms.InterpolationMode.BICUBIC,
+                    ),
+                    transforms.AugMix(
+                        9, 9, interpolation=transforms.InterpolationMode.BILINEAR
+                    ),
+                    transforms.RandomHorizontalFlip(0.5),
+                    transforms.RandomVerticalFlip(0.5),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )
         else:
-            self.transform = transforms.Compose([
-                transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-
+            self.transform = transforms.Compose(
+                [
+                    transforms.Resize(
+                        (384, 384), interpolation=transforms.InterpolationMode.BICUBIC
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )

     def __len__(self):
         return len(self.images)

     def __getitem__(self, idx):
-        img_path = os.path.join(self.path, 'images/'+ self.img_paths[int(self.images[idx, 0])-1, 1])
-        label = int(self.img_paths[int(self.images[idx, 0])-1, 1][:3]) - 1
+        img_path = os.path.join(
+            self.path, "images/" + self.img_paths[int(self.images[idx, 0]) - 1, 1]
+        )
+        label = int(self.img_paths[int(self.images[idx, 0]) - 1, 1][:3]) - 1
         img = Image.open(img_path)
-        if len(np.array(img).shape)==2:
-            img_path = os.path.join(self.path, 'images/'+ self.img_paths[int(self.images[idx-1, 0])-1, 1])
-            label = int(self.img_paths[int(self.images[idx-1, 0])-1, 1][:3]) - 1
+        if len(np.array(img).shape) == 2:
+            img_path = os.path.join(
+                self.path,
+                "images/" + self.img_paths[int(self.images[idx - 1, 0]) - 1, 1],
+            )
+            label = int(self.img_paths[int(self.images[idx - 1, 0]) - 1, 1][:3]) - 1
             img = Image.open(img_path)
-            #img = Image.fromarray(np.stack(np.array(img), np.array(img), np.array(img)), axis=-1)
+            # img = Image.fromarray(np.stack(np.array(img), np.array(img), np.array(img)), axis=-1)
         img = self.transform(img)
         return img, torch.tensor(label)

-if __name__=='__main__':
+
+if __name__ == "__main__":
     f = open("log.txt", "w")
-    #with redirect_stdout(f), redirect_stderr(f):
+    # with redirect_stdout(f), redirect_stderr(f):
     if True:
         torch.cuda.empty_cache()
         logger = WandbLogger(project="Fine Grained", name="CUB")
-        path = '/scratch1/fs1/jacobsn/s.sastry/CUB_200_2011'
+        path = "/scratch1/fs1/jacobsn/s.sastry/CUB_200_2011"
         train_dataset = CUBBirds(path)
         val_dataset = CUBBirds(path, val=True)
         checkpoint = ModelCheckpoint(
-            monitor='val_loss',
-            dirpath='checkpoints',
-            filename='CUBv1-{epoch:02d}-{val_loss:.2f}',
-            mode='min'
-        )
-
-
+            monitor="val_loss",
+            dirpath="checkpoints",
+            filename="CUBv1-{epoch:02d}-{val_loss:.2f}",
+            mode="min",
+        )
+
         model = CUBDownstream(train_dataset, val_dataset)
-        #model = model.load_from_checkpoint("/storage1/fs1/jacobsn/Active/user_s.sastry/checkpoints/ContrastiveDownstreamGeoMAEv7-epoch=94-val_loss=2.77.ckpt", train_dataset=train_dataset, val_dataset=val_dataset)
+        # model = model.load_from_checkpoint("/storage1/fs1/jacobsn/Active/user_s.sastry/checkpoints/ContrastiveDownstreamGeoMAEv7-epoch=94-val_loss=2.77.ckpt", train_dataset=train_dataset, val_dataset=val_dataset)
         trainer = pl.Trainer(
-            accelerator='gpu',
+            accelerator="gpu",
             devices=2,
-            strategy='ddp_find_unused_parameters_true',
+            strategy="ddp_find_unused_parameters_true",
             max_epochs=1500,
             num_nodes=1,
             callbacks=[checkpoint],
-            logger=logger
-        )
+            logger=logger,
+        )
         trainer.fit(model)
         """predloader = DataLoader(train_dataset, shuffle=False, batch_size=64, num_workers=8,
/home/runner/work/BirdSAT/BirdSAT/Downstream/ContGeoMAEDownstream.py#L1
-from .MAEPretrain_SceneClassification.models_mae_vitae import mae_vitae_base_patch16_dec512d8b, MaskedAutoencoderViTAE
-import torch
+from .MAEPretrain_SceneClassification.models_mae_vitae import (
+    mae_vitae_base_patch16_dec512d8b,
+    MaskedAutoencoderViTAE,
+)
+import torch
 import torch.nn as nn
 import pytorch_lightning as pl
 from pytorch_lightning import LightningModule
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms, datasets
/home/runner/work/BirdSAT/BirdSAT/Downstream/ContGeoMAEDownstream.py#L25
 from timm.data import Mixup
 from timm.data import create_transform
 from timm.loss import SoftTargetCrossEntropy
 from timm.utils import accuracy
+
 class MaeBirds(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
         self.sat_encoder = mae_vitae_base_patch16_dec512d8b()
-        self.sat_encoder.load_state_dict(torch.load('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth')['model'])
+        self.sat_encoder.load_state_dict(
+            torch.load(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth"
+            )["model"]
+        )
         self.sat_encoder.requires_grad_(False)
-        self.ground_encoder = MaskedAutoencoderViTAE(img_size=384, patch_size=32, in_chans=3,
-                embed_dim=768, depth=12, num_heads=12,
-                decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
-                mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_pix_loss=False, kernel=3, mlp_hidden_dim=None)
+        self.ground_encoder = MaskedAutoencoderViTAE(
+            img_size=384,
+            patch_size=32,
+            in_chans=3,
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            decoder_embed_dim=512,
+            decoder_depth=8,
+            decoder_num_heads=16,
+            mlp_ratio=4.0,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_pix_loss=False,
+            kernel=3,
+            mlp_hidden_dim=None,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 77)
-        self.num_workers = kwargs.get('num_workers', 16)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 77)
+        self.num_workers = kwargs.get("num_workers", 16)
+        self.lr = kwargs.get("lr", 0.02)
         self.geo_encode = nn.Linear(4, 768)
         self.date_encode = nn.Linear(4, 768)

     def forward(self, img_ground, geoloc, date):
         geo_token = self.geo_encode(geoloc)
         date_token = self.date_encode(date)
         ground_embeddings, *_ = self.ground_encoder.forward_encoder(img_ground, 0)
         return F.normalize(ground_embeddings[:, 0] + geo_token + date_token, dim=-1)
+
 class MaeBirdsDownstream(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
-        self.model = MaeBirds.load_from_checkpoint('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveGeoDateMAEv5-epoch=28-val_loss=1.53.ckpt', train_dataset=train_dataset, val_dataset=val_dataset)
+        self.model = MaeBirds.load_from_checkpoint(
+            "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveGeoDateMAEv5-epoch=28-val_loss=1.53.ckpt",
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 64)
-        self.num_workers = kwargs.get('num_workers', 8)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 64)
+        self.num_workers = kwargs.get("num_workers", 8)
+        self.lr = kwargs.get("lr", 0.02)
         self.classify = nn.Linear(768, 1486)
         self.criterion = SoftTargetCrossEntropy()
-        #self.acc = Accuracy(task='multiclass', num_classes=1486)
+        # self.acc = Accuracy(task='multiclass', num_classes=1486)
         self.mixup_fn = Mixup(
-            mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
-            prob=1.0, switch_prob=0.5, mode='batch',
-            label_smoothing=0.1, num_classes=1486)
+            mixup_alpha=0.8,
+            cutmix_alpha=1.0,
+            cutmix_minmax=None,
+            prob=1.0,
+            switch_prob=0.5,
+            mode="batch",
+            label_smoothing=0.1,
+            num_classes=1486,
+        )

     def forward(self, img_ground, geoloc, date):
         return self.classify(self.model(img_ground, geoloc, date))

     def shared_step(self, batch, batch_idx):
         img_ground, geoloc, date, labels = batch[0], batch[1], batch[2], batch[3]
         img_ground, labels_mix = self.mixup_fn(img_ground, labels)
-        #import code; code.interact(local=locals());
+        # import code; code.interact(local=locals());
         preds = self(img_ground, geoloc, date)
-        #import code; code.interact(local=locals());
+        # import code; code.interact(local=locals());
         loss = self.criterion(preds, labels_mix)
-        #acc = self.acc(preds, labels)
+        # acc = self.acc(preds, labels)
         acc = sum(accuracy(preds, labels)) / preds.shape[0]
         return loss, acc

     def training_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch, batch_idx)
-        self.log('train_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)
-        self.log('train_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True)
+        self.log("train_acc", acc, on_epoch=True, prog_bar=True, sync_dist=True)
+        self.log("train_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
         return {"loss": loss, "acc": acc}

     def validation_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch, batch_idx)
-        self.log('val_acc', acc, prog_bar=True, on_epoch=True, sync_dist=True)
-        self.log('val_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True)
-        return {"loss": loss, "acc":acc}
+        self.log("val_acc", acc, prog_bar=True, on_epoch=True, sync_dist=True)
+        self.log("val_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
+        return {"loss": loss, "acc": acc}

     def train_dataloader(self):
-        return DataLoader(self.train_dataset,
-                shuffle=True,
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-                persistent_workers=False,
-                pin_memory=True)
+        return DataLoader(
+            self.train_dataset,
+            shuffle=True,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=False,
+            pin_memory=True,
+        )

     def val_dataloader(self):
-        return DataLoader(self.val_dataset,
-                shuffle=False,
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-                persistent_workers=True,
-                pin_memory=True)
+        return DataLoader(
+            self.val_dataset,
+            shuffle=False,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=True,
+            pin_memory=True,
+        )

     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=0.001)
         scheduler = CosineAnnealingWarmRestarts(optimizer, 40)
         return [optimizer], [scheduler]
+
 class Birds(Dataset):
     def __init__(self, dataset, label, val=False):
         self.dataset = dataset
-        self.images = np.array(self.dataset['images'])
-        self.labels = np.array(self.dataset['categories'])
+        self.images = np.array(self.dataset["images"])
+        self.labels = np.array(self.dataset["categories"])
         self.species = {}
         for i in range(len(self.labels)):
-            self.species[self.labels[i]['id']] = i
-        self.categories = np.array(self.dataset['annotations'])
+            self.species[self.labels[i]["id"]] = i
+        self.categories = np.array(self.dataset["annotations"])
         self.idx = np.array(label.iloc[:, 1]).astype(int)
         self.images = self.images[self.idx]
         self.categories = self.categories[self.idx]
         self.val = val
         if not val:
             self.transform_ground = create_transform(
-                    input_size=384,
-                    is_training=True,
-                    color_jitter=0.4,
-                    auto_augment='rand-m9-mstd0.5-inc1',
-                    re_prob=0.25,
-                    re_mode='pixel',
-                    re_count=1,
-                    interpolation='bicubic',
-                )
+                input_size=384,
+                is_training=True,
+                color_jitter=0.4,
+                auto_augment="rand-m9-mstd0.5-inc1",
+                re_prob=0.25,
+                re_mode="pixel",
+                re_count=1,
+                interpolation="bicubic",
+            )
             # self.transform_ground = transforms.Compose([
             #     transforms.Resize((384, 384)),
             #     transforms.AutoAugment(),
             #     transforms.AugMix(5, 5),
             #     transforms.RandomHorizontalFlip(0.5),
             #     transforms.RandomVerticalFlip(0.5),
             #     transforms.ToTensor(),
             #     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
             # ])
         else:
-            self.transform_ground = transforms.Compose([
-                transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-
+            self.transform_ground = transforms.Compose(
+                [
+                    transforms.Resize(
+                        (384, 384), interpolation=transforms.InterpolationMode.BICUBIC
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )

     def __len__(self):
         return len(self.images)

     def __getitem__(self, idx):
-        img_path = self.images[idx]['file_name']
-        label = self.species[self.categories[idx]['category_id']]
-        img_ground = Image.open(os.path.join('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/', img_path))
+        img_path = self.images[idx]["file_name"]
+        label = self.species[self.categories[idx]["category_id"]]
+        img_ground = Image.open(
+            os.path.join(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/", img_path
+            )
+        )
         img_ground = self.transform_ground(img_ground)
-        lat = self.images[idx]['latitude']
-        lon = self.images[idx]['longitude']
-        date = self.images[idx]['date'].split(" ")[0]
-        month = int(datetime.strptime(date, '%Y-%m-%d').date().strftime('%m'))
-        day = int(datetime.strptime(date, '%Y-%m-%d').date().strftime('%d'))
-        date_encode = torch.tensor([np.sin(2*np.pi*month/12), np.cos(2*np.pi*month/12), np.sin(2*np.pi*day/31), np.cos(2*np.pi*day/31)])
-        return img_ground, torch.tensor([np.sin(np.pi*lat/90), np.cos(np.pi*lat/90), np.sin(np.pi*lon/180), np.cos(np.pi*lon/180)]).float(), date_encode.float(), torch.tensor(label)
-
-if __name__=='__main__':
+        lat = self.images[idx]["latitude"]
+        lon = self.images[idx]["longitude"]
+        date = self.images[idx]["date"].split(" ")[0]
+        month = int(datetime.strptime(date, "%Y-%m-%d").date().strftime("%m"))
+        day = int(datetime.strptime(date, "%Y-%m-%d").date().strftime("%d"))
+        date_encode = torch.tensor(
+            [
+                np.sin(2 * np.pi * month / 12),
+                np.cos(2 * np.pi * month / 12),
+                np.sin(2 * np.pi * day / 31),
+                np.cos(2 * np.pi * day / 31),
+            ]
+        )
+        return (
+            img_ground,
+            torch.tensor(
+                [
+                    np.sin(np.pi * lat / 90),
+                    np.cos(np.pi * lat / 90),
+                    np.sin(np.pi * lon / 180),
+                    np.cos(np.pi * lon / 180),
+                ]
+            ).float(),
+            date_encode.float(),
+            torch.tensor(label),
+        )
+
+
+if __name__ == "__main__":
     f = open("log.txt", "w")
-    #with redirect_stdout(f), redirect_stderr(f):
+    # with redirect_stdout(f), redirect_stderr(f):
     if True:
         torch.cuda.empty_cache()
         logger = WandbLogger(project="Cross-View-MAE", name="Downstram Cont MAE")
-        train_dataset = json.load(open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/train_birds.json"))
-        train_labels = pd.read_csv('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/train_birds_labels.csv')
+        train_dataset = json.load(
+            open(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/train_birds.json"
+            )
+        )
+        train_labels = pd.read_csv(
+            "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/train_birds_labels.csv"
+        )
         train_dataset = Birds(train_dataset, train_labels)
-        val_dataset = json.load(open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds.json"))
-        val_labels = pd.read_csv('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds_labels.csv')
+        val_dataset = json.load(
+            open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds.json")
+        )
+        val_labels = pd.read_csv(
+            "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds_labels.csv"
+        )
         val_dataset = Birds(val_dataset, val_labels, val=True)
         checkpoint = ModelCheckpoint(
-            monitor='val_loss',
-            dirpath='checkpoints',
-            filename='ContrastiveDownstreamGeoMAEv7-{epoch:02d}-{val_loss:.2f}',
-            mode='min'
-        )
-
-
+            monitor="val_loss",
+            dirpath="checkpoints",
+            filename="ContrastiveDownstreamGeoMAEv7-{epoch:02d}-{val_loss:.2f}",
+            mode="min",
+        )
+
         model = MaeBirdsDownstream(train_dataset, val_dataset)
-        #model = model.load_from_checkpoint("/storage1/fs1/jacobsn/Active/user_s.sastry/checkpoints/ContrastiveDownstreamGeoMAEv7-epoch=94-val_loss=2.77.ckpt", train_dataset=train_dataset, val_dataset=val_dataset)
+        # model = model.load_from_checkpoint("/storage1/fs1/jacobsn/Active/user_s.sastry/checkpoints/ContrastiveDownstreamGeoMAEv7-epoch=94-val_loss=2.77.ckpt", train_dataset=train_dataset, val_dataset=val_dataset)
         trainer = pl.Trainer(
-            accelerator='gpu',
+            accelerator="gpu",
             devices=4,
-            strategy='ddp_find_unused_parameters_true',
+            strategy="ddp_find_unused_parameters_true",
             max_epochs=1500,
             num_nodes=1,
             callbacks=[checkpoint],
-            logger=logger
-        )
+            logger=logger,
+        )
         trainer.fit(model)
/home/runner/work/BirdSAT/BirdSAT/Downstream/CrossMAEDownstream.py#L1
-from .MAEPretrain_SceneClassification.models_mae_vitae import mae_vitae_base_patch16_dec512d8b, MaskedAutoencoderViTAE
-import torch
+from .MAEPretrain_SceneClassification.models_mae_vitae import (
+    mae_vitae_base_patch16_dec512d8b,
+    MaskedAutoencoderViTAE,
+)
+import torch
 import torch.nn as nn
 import pytorch_lightning as pl
 from pytorch_lightning import LightningModule
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms, datasets
/home/runner/work/BirdSAT/BirdSAT/Downstream/CrossMAEDownstream.py#L30
 class MaeBirds(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
         self.sat_encoder = mae_vitae_base_patch16_dec512d8b()
-        self.sat_encoder.load_state_dict(torch.load('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth')['model'])
+        self.sat_encoder.load_state_dict(
+            torch.load(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth"
+            )["model"]
+        )
         self.sat_encoder.requires_grad_(False)
-        self.ground_encoder = MaskedAutoencoderViTAE(img_size=384, patch_size=32, in_chans=3,
-                embed_dim=768, depth=12, num_heads=12,
-                decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
-                mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_pix_loss=False, kernel=3, mlp_hidden_dim=None)
+        self.ground_encoder = MaskedAutoencoderViTAE(
+            img_size=384,
+            patch_size=32,
+            in_chans=3,
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            decoder_embed_dim=512,
+            decoder_depth=8,
+            decoder_num_heads=16,
+            mlp_ratio=4.0,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_pix_loss=False,
+            kernel=3,
+            mlp_hidden_dim=None,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 77)
-        self.num_workers = kwargs.get('num_workers', 16)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 77)
+        self.num_workers = kwargs.get("num_workers", 16)
+        self.lr = kwargs.get("lr", 0.02)

     def forward(self, img_ground):
         ground_embeddings, *_ = self.ground_encoder.forward_encoder(img_ground, 0)
         return F.normalize(ground_embeddings[:, 0], dim=-1)
+
 class MaeBirdsDownstream(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
-        self.model = MaeBirds.load_from_checkpoint('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveMAEv5-epoch=44-val_loss=1.60.ckpt', train_dataset=train_dataset, val_dataset=val_dataset)
+        self.model = MaeBirds.load_from_checkpoint(
+            "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveMAEv5-epoch=44-val_loss=1.60.ckpt",
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 64)
-        self.num_workers = kwargs.get('num_workers', 8)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 64)
+        self.num_workers = kwargs.get("num_workers", 8)
+        self.lr = kwargs.get("lr", 0.02)
         self.classify = nn.Linear(768, 1486)
         self.criterion = SoftTargetCrossEntropy()
-        #self.acc = Accuracy(task='multiclass', num_classes=1486)
+        # self.acc = Accuracy(task='multiclass', num_classes=1486)
         self.mixup_fn = Mixup(
-            mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
-            prob=1.0, switch_prob=0.5, mode='batch',
-            label_smoothing=0.1, num_classes=1486)
+            mixup_alpha=0.8,
+            cutmix_alpha=1.0,
+            cutmix_minmax=None,
+            prob=1.0,
+            switch_prob=0.5,
+            mode="batch",
+            label_smoothing=0.1,
+            num_classes=1486,
+        )

     def forward(self, img_ground):
         return self.classify(self.model(img_ground))

     def shared_step(self, batch, batch_idx):
         img_ground, labels = batch[0], batch[1]
         img_ground, labels_mix = self.mixup_fn(img_ground, labels)
-        #import code; code.interact(local=locals());
+        # import code; code.interact(local=locals());
         preds = self(img_ground)
         loss = self.criterion(preds, labels_mix)
-        #acc = self.acc(preds, labels)
+        # acc = self.acc(preds, labels)
         acc = sum(accuracy(preds, labels)) / preds.shape[0]
         return loss, acc

     def training_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch, batch_idx)
-        self.log('train_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)
-        self.log('train_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True)
+        self.log("train_acc", acc, on_epoch=True, prog_bar=True, sync_dist=True)
+        self.log("train_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
         return {"loss": loss, "acc": acc}

     def validation_step(self, batch, batch_idx):
         loss, acc = self.shared_step(batch, batch_idx)
-        self.log('val_acc', acc, prog_bar=True, on_epoch=True, sync_dist=True)
-        self.log('val_loss', loss, prog_bar=True, on_epoch=True, sync_dist=True)
-        return {"loss": loss, "acc":acc}
+        self.log("val_acc", acc, prog_bar=True, on_epoch=True, sync_dist=True)
+        self.log("val_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
+        return {"loss": loss, "acc": acc}

     def train_dataloader(self):
-        return DataLoader(self.train_dataset,
-                shuffle=True,
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-                persistent_workers=False,
-                pin_memory=True)
+        return DataLoader(
+            self.train_dataset,
+            shuffle=True,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=False,
+            pin_memory=True,
+        )

     def val_dataloader(self):
-        return DataLoader(self.val_dataset,
-                shuffle=False,
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-                persistent_workers=True,
-                pin_memory=True)
+        return DataLoader(
+            self.val_dataset,
+            shuffle=False,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=True,
+            pin_memory=True,
+        )

     def configure_optimizers(self):
         optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=0.001)
         scheduler = CosineAnnealingWarmRestarts(optimizer, 40)
         return [optimizer], [scheduler]
+
 class Birds(Dataset):
     def __init__(self, dataset, label, val=False):
         self.dataset = dataset
-        self.images = np.array(self.dataset['images'])
-        self.labels = np.array(self.dataset['categories'])
+        self.images = np.array(self.dataset["images"])
+        self.labels = np.array(self.dataset["categories"])
         self.species = {}
         for i in range(len(self.labels)):
-            self.species[self.labels[i]['id']] = i
-        self.categories = np.array(self.dataset['annotations'])
+            self.species[self.labels[i]["id"]] = i
+        self.categories = np.array(self.dataset["annotations"])
         self.idx = np.array(label.iloc[:, 1]).astype(int)
         self.images = self.images[self.idx]
         self.categories = self.categories[self.idx]
         self.val = val
         if not val:
             self.transform_ground = create_transform(
-                    input_size=384,
-                    is_training=True,
-                    color_jitter=0.4,
-                    auto_augment='rand-m9-mstd0.5-inc1',
-                    re_prob=0.25,
-                    re_mode='pixel',
-                    re_count=1,
-                    interpolation='bicubic',
-                )
+                input_size=384,
+                is_training=True,
+                color_jitter=0.4,
+                auto_augment="rand-m9-mstd0.5-inc1",
+                re_prob=0.25,
+                re_mode="pixel",
+                re_count=1,
+                interpolation="bicubic",
+            )
             # self.transform_ground = transforms.Compose([
             #     transforms.Resize((384, 384)),
             #     transforms.AutoAugment(),
             #     transforms.AugMix(5, 5),
             #     transforms.RandomHorizontalFlip(0.5),
             #     transforms.RandomVerticalFlip(0.5),
             #     transforms.ToTensor(),
             #     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
             # ])
         else:
-            self.transform_ground = transforms.Compose([
-                transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-
+            self.transform_ground = transforms.Compose(
+                [
+                    transforms.Resize(
+                        (384, 384), interpolation=transforms.InterpolationMode.BICUBIC
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )

     def __len__(self):
         return len(self.images)

     def __getitem__(self, idx):
-        img_path = self.images[idx]['file_name']
-        label = self.species[self.categories[idx]['category_id']]
-        img_ground = Image.open(os.path.join('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/', img_path))
+        img_path = self.images[idx]["file_name"]
+        label = self.species[self.categories[idx]["category_id"]]
+        img_ground = Image.open(
+            os.path.join(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/", img_path
+            )
+        )
         img_ground = self.transform_ground(img_ground)
         return img_ground, torch.tensor(label)

-if __name__=='__main__':
+
+if __name__ == "__main__":
     f = open("log.txt", "w")
     with redirect_stdout(f), redirect_stderr(f):
         torch.cuda.empty_cache()
         logger = WandbLogger(project="Cross-View-MAE", name="Downstram Cont MAE")
-        train_dataset = json.load(open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/train_birds.json"))
-        train_labels = pd.read_csv('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/train_birds_labels.csv')
+        train_dataset = json.load(
+            open(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/train_birds.json"
+            )
+        )
+        train_labels = pd.read_csv(
+            "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/train_birds_labels.csv"
+        )
         train_dataset = Birds(train_dataset, train_labels)
-        val_dataset = json.load(open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds.json"))
-        val_labels = pd.read_csv('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds_labels.csv')
+        val_dataset = json.load(
+            open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds.json")
+        )
+        val_labels = pd.read_csv(
+            "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds_labels.csv"
+        )
         val_dataset = Birds(val_dataset, val_labels, val=True)
         checkpoint = ModelCheckpoint(
-            monitor='val_loss',
-            dirpath='checkpoints',
-            filename='ContrastiveDownstreamMAEv8-{epoch:02d}-{val_loss:.2f}',
-            mode='min'
-        )
-
-
+            monitor="val_loss",
+            dirpath="checkpoints",
+            filename="ContrastiveDownstreamMAEv8-{epoch:02d}-{val_loss:.2f}",
+            mode="min",
+        )
+
         model = MaeBirdsDownstream(train_dataset, val_dataset)
-        #model = model.load_from_checkpoint("/storage1/fs1/jacobsn/Active/user_s.sastry/checkpoints/ContrastiveDownstreamMAEv8-epoch=59-val_loss=3.54.ckpt", train_dataset=train_dataset, val_dataset=val_dataset)
+        # model = model.load_from_checkpoint("/storage1/fs1/jacobsn/Active/user_s.sastry/checkpoints/ContrastiveDownstreamMAEv8-epoch=59-val_loss=3.54.ckpt", train_dataset=train_dataset, val_dataset=val_dataset)
         trainer = pl.Trainer(
-            accelerator='gpu',
+            accelerator="gpu",
             devices=4,
-            strategy='ddp_find_unused_parameters_true',
+            strategy="ddp_find_unused_parameters_true",
             max_epochs=1500,
             num_nodes=1,
             callbacks=[checkpoint],
-            logger=logger)
+            logger=logger,
+        )
         trainer.fit(model)
/home/runner/work/BirdSAT/BirdSAT/Retrieval/RecallContGeoMAE.py#L1
-from .MAEPretrain_SceneClassification.models_mae_vitae import mae_vitae_base_patch16_dec512d8b, MaskedAutoencoderViTAE
-import torch
+from .MAEPretrain_SceneClassification.models_mae_vitae import (
+    mae_vitae_base_patch16_dec512d8b,
+    MaskedAutoencoderViTAE,
+)
+import torch
 import torch.nn as nn
 import pytorch_lightning as pl
 from pytorch_lightning import LightningModule
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms, datasets
/home/runner/work/BirdSAT/BirdSAT/Retrieval/RecallContGeoMAE.py#L27
 class MaeBirds(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
         self.sat_encoder = mae_vitae_base_patch16_dec512d8b()
-        self.sat_encoder.load_state_dict(torch.load('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth')['model'])
+        self.sat_encoder.load_state_dict(
+            torch.load(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth"
+            )["model"]
+        )
         self.sat_encoder.requires_grad_(False)
-        self.ground_encoder = MaskedAutoencoderViTAE(img_size=384, patch_size=32, in_chans=3,
-                embed_dim=768, depth=12, num_heads=12,
-                decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
-                mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_pix_loss=False, kernel=3, mlp_hidden_dim=None)
+        self.ground_encoder = MaskedAutoencoderViTAE(
+            img_size=384,
+            patch_size=32,
+            in_chans=3,
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            decoder_embed_dim=512,
+            decoder_depth=8,
+            decoder_num_heads=16,
+            mlp_ratio=4.0,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_pix_loss=False,
+            kernel=3,
+            mlp_hidden_dim=None,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 77)
-        self.num_workers = kwargs.get('num_workers', 16)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 77)
+        self.num_workers = kwargs.get("num_workers", 16)
+        self.lr = kwargs.get("lr", 0.02)
         self.geo_encode = nn.Linear(4, 768)
         self.date_encode = nn.Linear(4, 768)

     def forward(self, img_ground, img_overhead, geoloc, date):
         geo_token = self.geo_encode(geoloc)
         date_token = self.date_encode(date)
         ground_embeddings, *_ = self.ground_encoder.forward_encoder(img_ground, 0)
         sat_embeddings, *_ = self.sat_encoder.forward_encoder(img_overhead, 0)
-        return F.normalize(ground_embeddings[:, 0], dim=-1), F.normalize(sat_embeddings[:, 0] + geo_token + date_token, dim=-1)
+        return F.normalize(ground_embeddings[:, 0], dim=-1), F.normalize(
+            sat_embeddings[:, 0] + geo_token + date_token, dim=-1
+        )
+
 class Birds(Dataset):
     def __init__(self, dataset, label, val=False):
         self.dataset = dataset
-        self.images = np.array(self.dataset['images'])
-        self.labels = np.array(self.dataset['categories'])
+        self.images = np.array(self.dataset["images"])
+        self.labels = np.array(self.dataset["categories"])
         self.species = {}
         for i in range(len(self.labels)):
-            self.species[self.labels[i]['id']] = i
-        self.categories = np.array(self.dataset['annotations'])
+            self.species[self.labels[i]["id"]] = i
+        self.categories = np.array(self.dataset["annotations"])
         self.idx = np.array(label.iloc[:, 1]).astype(int)
         self.images = self.images[self.idx]
         self.categories = self.categories[self.idx]
         self.val = val
         if not val:
-            self.transform_ground = transforms.Compose([
-                transforms.Resize((384, 384)),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-            self.transform_overhead = transforms.Compose([
-                transforms.Resize(224),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
+            self.transform_ground = transforms.Compose(
+                [
+                    transforms.Resize((384, 384)),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )
+            self.transform_overhead = transforms.Compose(
+                [
+                    transforms.Resize(224),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )
         else:
-            self.transform_ground = transforms.Compose([
-                transforms.Resize((384, 384)),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-            self.transform_overhead = transforms.Compose([
-                transforms.Resize(224),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-
+            self.transform_ground = transforms.Compose(
+                [
+                    transforms.Resize((384, 384)),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )
+            self.transform_overhead = transforms.Compose(
+                [
+                    transforms.Resize(224),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )

     def __len__(self):
         return len(self.images)

     def __getitem__(self, idx):
-        img_path = self.images[idx]['file_name']
-        lat = self.images[idx]['latitude']
-        lon = self.images[idx]['longitude']
-        date = self.images[idx]['date'].split(" ")[0]
-        month = int(datetime.strptime(date, '%Y-%m-%d').date().strftime('%m'))
-        day = int(datetime.strptime(date, '%Y-%m-%d').date().strftime('%d'))
-        date_encode = torch.tensor([np.sin(2*np.pi*month/12), np.cos(2*np.pi*month/12), np.sin(2*np.pi*day/31), np.cos(2*np.pi*day/31)])
-        label = self.species[self.categories[idx]['category_id']]
-        img_ground = Image.open(os.path.join('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/', img_path))
+        img_path = self.images[idx]["file_name"]
+        lat = self.images[idx]["latitude"]
+        lon = self.images[idx]["longitude"]
+        date = self.images[idx]["date"].split(" ")[0]
+        month = int(datetime.strptime(date, "%Y-%m-%d").date().strftime("%m"))
+        day = int(datetime.strptime(date, "%Y-%m-%d").date().strftime("%d"))
+        date_encode = torch.tensor(
+            [
+                np.sin(2 * np.pi * month / 12),
+                np.cos(2 * np.pi * month / 12),
+                np.sin(2 * np.pi * day / 31),
+                np.cos(2 * np.pi * day / 31),
+            ]
+        )
+        label = self.species[self.categories[idx]["category_id"]]
+        img_ground = Image.open(
+            os.path.join(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/", img_path
+            )
+        )
         img_ground = self.transform_ground(img_ground)
         if not self.val:
-            img_overhead = Image.open(f"/scratch1/fs1/jacobsn/s.sastry/metaformer/train_overhead/images_sentinel/{idx}.jpeg")
+            img_overhead = Image.open(
+                f"/scratch1/fs1/jacobsn/s.sastry/metaformer/train_overhead/images_sentinel/{idx}.jpeg"
+            )
         else:
-            img_overhead = Image.open(f"/scratch1/fs1/jacobsn/s.sastry/metaformer/val_overhead/images_sentinel/{idx}.jpeg")
+            img_overhead = Image.open(
+                f"/scratch1/fs1/jacobsn/s.sastry/metaformer/val_overhead/images_sentinel/{idx}.jpeg"
+            )
         img_overhead = self.transform_overhead(img_overhead)
-        return img_ground, img_overhead, torch.tensor([np.sin(np.pi*lat/90), np.cos(np.pi*lat/90), np.sin(np.pi*lon/180), np.cos(np.pi*lon/180)]).float(), date_encode.float(), torch.tensor(label)
-
-if __name__=='__main__':
+        return (
+            img_ground,
+            img_overhead,
+            torch.tensor(
+                [
+                    np.sin(np.pi * lat / 90),
+                    np.cos(np.pi * lat / 90),
+                    np.sin(np.pi * lon / 180),
+                    np.cos(np.pi * lon / 180),
+                ]
+            ).float(),
+            date_encode.float(),
+            torch.tensor(label),
+        )
+
+
+if __name__ == "__main__":
     torch.cuda.empty_cache()
-    val_dataset = json.load(open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds.json"))
-    val_labels = pd.read_csv('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds_labels.csv')
+    val_dataset = json.load(
+        open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds.json")
+    )
+    val_labels = pd.read_csv(
+        "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds_labels.csv"
+    )
     val_dataset = Birds(val_dataset, val_labels, val=True)
-    model = MaeBirds.load_from_checkpoint('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveGeoDateMAEv5-epoch=28-val_loss=1.53.ckpt', train_dataset=val_dataset, val_dataset=val_dataset)
-
+    model = MaeBirds.load_from_checkpoint(
+        "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveGeoDateMAEv5-epoch=28-val_loss=1.53.ckpt",
+        train_dataset=val_dataset,
+        val_dataset=val_dataset,
+    )
+
     model = model.eval()
-    val_overhead = DataLoader(val_dataset,
-            shuffle=False,
-            batch_size=77,
-            num_workers=8,
-            persistent_workers=False,
-            pin_memory=True,
-            drop_last=True
-            )
-
+    val_overhead = DataLoader(
+        val_dataset,
+        shuffle=False,
+        batch_size=77,
+        num_workers=8,
+        persistent_workers=False,
+        pin_memory=True,
+        drop_last=True,
+    )
+
    recall = 0
    for batch in tqdm(val_overhead):
-        #for batch2 in tqdm(val_overhead):
+        # for batch2 in tqdm(val_overhead):
        img_ground, img_overhead, geoloc, date, label = batch
-        z = 0
+        z = 0
        running_val = 0
        running_label = 0
        for batch2 in tqdm(val_overhead):
            img_ground2, img_overhead2, geoloc2, date2, label2 = batch2
-            ground_embeddings, overhead_embeddings = model(img_ground2.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda())
-            similarity = torch.einsum('ij,kj->ik', ground_embeddings, overhead_embeddings)
+            ground_embeddings, overhead_embeddings = model(
+                img_ground2.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
+            )
+            similarity = torch.einsum(
+                "ij,kj->ik", ground_embeddings, overhead_embeddings
+            )
            vals, ind = torch.topk(similarity.detach().cpu(), 5, dim=0)
-            if z==0:
+            if z == 0:
                running_val = vals
                running_label = label2[ind]
-                z+=1
+                z += 1
            else:
                running_val = torch.cat((running_val, vals), dim=0)
                running_label = torch.cat((running_label, label2[ind]), dim=0)
        _, ind = torch.topk(running_val, 5, dim=0)
-        #import code; code.interact(local=locals())
+        # import code; code.interact(local=locals())
        preds = running_label[ind]
-        recall+=sum([1 if label[i] in preds[:, i] else 0 for i in range(label.shape[0])])
-        #import code; code.interact(local=locals())
-        print(f"Current Recall Score: {recall}")
+        recall += sum(
+            [1 if label[i] in preds[:, i] else 0 for i in range(label.shape[0])]
+        )
+        # import code; code.interact(local=locals())
+        print(f"Current Recall Score: {recall}")
/home/runner/work/BirdSAT/BirdSAT/Retrieval/RecallContMAE.py#L1
-from .MAEPretrain_SceneClassification.models_mae_vitae import mae_vitae_base_patch16_dec512d8b, MaskedAutoencoderViTAE
-import torch
+from .MAEPretrain_SceneClassification.models_mae_vitae import (
+    mae_vitae_base_patch16_dec512d8b,
+    MaskedAutoencoderViTAE,
+)
+import torch
 import torch.nn as nn
 import pytorch_lightning as pl
 from pytorch_lightning import LightningModule
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms, datasets
/home/runner/work/BirdSAT/BirdSAT/Retrieval/RecallContMAE.py#L22
 import copy
 import os
 from tqdm import tqdm
 from functools import partial
+
 class MaeBirds(LightningModule):
     def __init__(self, train_dataset, val_dataset, **kwargs):
         super().__init__()
         self.sat_encoder = mae_vitae_base_patch16_dec512d8b()
-        self.sat_encoder.load_state_dict(torch.load('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth')['model'])
+        self.sat_encoder.load_state_dict(
+            torch.load(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/vitae-b-checkpoint-1599-transform-no-average.pth"
+            )["model"]
+        )
         self.sat_encoder.requires_grad_(False)
-        self.ground_encoder = MaskedAutoencoderViTAE(img_size=384, patch_size=32, in_chans=3,
-                embed_dim=768, depth=12, num_heads=12,
-                decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
-                mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_pix_loss=False, kernel=3, mlp_hidden_dim=None)
+        self.ground_encoder = MaskedAutoencoderViTAE(
+            img_size=384,
+            patch_size=32,
+            in_chans=3,
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            decoder_embed_dim=512,
+            decoder_depth=8,
+            decoder_num_heads=16,
+            mlp_ratio=4.0,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_pix_loss=False,
+            kernel=3,
+            mlp_hidden_dim=None,
+        )
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
-        self.batch_size = kwargs.get('batch_size', 77)
-        self.num_workers = kwargs.get('num_workers', 16)
-        self.lr = kwargs.get('lr', 0.02)
+        self.batch_size = kwargs.get("batch_size", 77)
+        self.num_workers = kwargs.get("num_workers", 16)
+        self.lr = kwargs.get("lr", 0.02)

     def forward(self, img_ground, img_overhead):
         ground_embeddings, *_ = self.ground_encoder.forward_encoder(img_ground, 0)
         sat_embeddings, *_ = self.sat_encoder.forward_encoder(img_overhead, 0)
         return ground_embeddings[:, 0], sat_embeddings[:, 0]
+
 class Birds(Dataset):
     def __init__(self, dataset, label, val=False):
         self.dataset = dataset
-        self.images = np.array(self.dataset['images'])
-        self.labels = np.array(self.dataset['categories'])
+        self.images = np.array(self.dataset["images"])
+        self.labels = np.array(self.dataset["categories"])
         self.species = {}
         for i in range(len(self.labels)):
-            self.species[self.labels[i]['id']] = i
-        self.categories = np.array(self.dataset['annotations'])
+            self.species[self.labels[i]["id"]] = i
+        self.categories = np.array(self.dataset["annotations"])
         self.idx = np.array(label.iloc[:, 1]).astype(int)
         self.images = self.images[self.idx]
         self.categories = self.categories[self.idx]
         self.val = val
         if not val:
-            self.transform_ground = transforms.Compose([
-                transforms.Resize((384, 384)),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-            self.transform_overhead = transforms.Compose([
-                transforms.Resize(224),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
+            self.transform_ground = transforms.Compose(
+                [
+                    transforms.Resize((384, 384)),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )
+            self.transform_overhead = transforms.Compose(
+                [
+                    transforms.Resize(224),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )
         else:
-            self.transform_ground = transforms.Compose([
-                transforms.Resize((384, 384)),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-            self.transform_overhead = transforms.Compose([
-                transforms.Resize(224),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-
+            self.transform_ground = transforms.Compose(
+                [
+                    transforms.Resize((384, 384)),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )
+            self.transform_overhead = transforms.Compose(
+                [
+                    transforms.Resize(224),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
+            )

     def __len__(self):
         return len(self.images)

     def __getitem__(self, idx):
-        img_path = self.images[idx]['file_name']
-        label = self.species[self.categories[idx]['category_id']]
-        img_ground = Image.open(os.path.join('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/', img_path))
+        img_path = self.images[idx]["file_name"]
+        label = self.species[self.categories[idx]["category_id"]]
+        img_ground = Image.open(
+            os.path.join(
+                "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/", img_path
+            )
+        )
         img_ground = self.transform_ground(img_ground)
         if not self.val:
-            img_overhead = Image.open(f"/scratch1/fs1/jacobsn/s.sastry/metaformer/train_overhead/images_sentinel/{idx}.jpeg")
+            img_overhead = Image.open(
+                f"/scratch1/fs1/jacobsn/s.sastry/metaformer/train_overhead/images_sentinel/{idx}.jpeg"
+            )
         else:
-            img_overhead = Image.open(f"/scratch1/fs1/jacobsn/s.sastry/metaformer/val_overhead/images_sentinel/{idx}.jpeg")
+            img_overhead = Image.open(
+                f"/scratch1/fs1/jacobsn/s.sastry/metaformer/val_overhead/images_sentinel/{idx}.jpeg"
+            )
         img_overhead = self.transform_overhead(img_overhead)
         return img_ground, img_overhead, torch.tensor(label)

-if __name__=='__main__':
+
+if __name__ == "__main__":
     torch.cuda.empty_cache()
-    val_dataset = json.load(open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds.json"))
-    val_labels = pd.read_csv('/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds_labels.csv')
+    val_dataset = json.load(
+        open("/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds.json")
+    )
+    val_labels = pd.read_csv(
+        "/storage1/fs1/jacobsn/Active/user_s.sastry/metaformer/val_birds_labels.csv"
+    )
     val_dataset = Birds(val_dataset, val_labels, val=True)
-    model = MaeBirds.load_from_checkpoint('/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveMAEv5-epoch=44-val_loss=1.60.ckpt', train_dataset=val_dataset, val_dataset=val_dataset)
+    model = MaeBirds.load_from_checkpoint(
+        "/storage1/fs1/jacobsn/Active/user_s.sastry/Remote-Sensing-RVSA/ContrastiveMAEv5-epoch=44-val_loss=1.60.ckpt",
+        train_dataset=val_dataset,
+        val_dataset=val_dataset,
+    )
     model = model.eval()
-    val_overhead = DataLoader(val_dataset,
-            shuffle=False,
-            batch_size=77,
-            num_workers=8,
-            persistent_workers=False,
-            pin_memory=True,
-            drop_last=True
-            )
-
+    val_overhead = DataLoader(
+        val_dataset,
+        shuffle=False,
+        batch_size=77,
+        num_workers=8,
+        persistent_workers=False,
+        pin_memory=True,
+        drop_last=True,
+    )
+
    recall = 0
    for batch in tqdm(val_overhead):
-        #for batch2 in tqdm(val_overhead):
+        # for batch2 in tqdm(val_overhead):
        img_ground, img_overhead, label = batch
-        z = 0
+        z = 0
        running_val = 0
        running_label = 0
        for batch2 in tqdm(val_overhead):
            img_ground2, img_overhead2, label2 = batch2
-            ground_embeddings, overhead_embeddings = model(img_ground2.cuda(), img_overhead.cuda())
+            ground_embeddings, overhead_embeddings = model(
+                img_ground2.cuda(), img_overhead.cuda()
+            )
            norm_ground_features = F.normalize(ground_embeddings, dim=-1)
            norm_overhead_features = F.normalize(overhead_embeddings, dim=-1)
-            similarity = torch.einsum('ij,kj->ik', norm_ground_features, norm_overhead_features)
+            similarity = torch.einsum(
+                "ij,kj->ik", norm_ground_features, norm_overhead_features
+            )
            vals, ind = torch.topk(similarity.detach().cpu(), 10, dim=0)
-            if z==0:
+            if z == 0:
                running_val = vals
                running_label = label2[ind]
-                z+=1
+                z += 1
            else:
                running_val = torch.cat((running_val, vals), dim=0)
                running_label = torch.cat((running_label, label2[ind]), dim=0)
        _, ind = torch.topk(running_val, 10, dim=0)
-        #import code; code.interact(local=locals())
+        # import code; code.interact(local=locals())
        preds = running_label[ind]
-        recall+=sum([1 if label[i] in preds[:, i] else 0 for i in range(label.shape[0])])
-        #import code; code.interact(local=locals())
+        recall += sum(
+            [1 if label[i] in preds[:, i] else 0 for i in range(label.shape[0])]
+        )
+        # import code; code.interact(local=locals())
        print(f"Current Recall Score: {recall}")
Run linters
The following actions uses node12 which is deprecated and will be forced to run on node16: actions/checkout@v2, actions/setup-python@v1. For more info: https://github.blog/changelog/2023-06-13-github-actions-all-actions-will-run-on-node16-instead-of-node12-by-default/
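Clearing this warning means moving both actions off their node12-based releases in linting.yml. A sketch of the relevant step change is below; the surrounding step layout is assumed, but actions/checkout@v4 and actions/setup-python@v4 are current releases that run on a supported Node runtime.

# Sketch of the version bump that resolves the node12 deprecation warning.
steps:
  - uses: actions/checkout@v4       # replaces the node12-based v2
  - uses: actions/setup-python@v4   # replaces the node12-based v1
    with:
      python-version: "3.10"        # assumed; pick the version the project targets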