From b5867fd4dae025b7a6c32ff1e159586386d3e991 Mon Sep 17 00:00:00 2001
From: Vishu26
Date: Sat, 28 Oct 2023 00:42:26 -0500
Subject: [PATCH] Polished

---
 datasets.py       | 52 +++++++++++++++++++++++++++--------------------
 retrieval_eval.py | 34 +++++++++++++++++--------------
 2 files changed, 49 insertions(+), 37 deletions(-)

diff --git a/datasets.py b/datasets.py
index d0c913a..4b599fc 100644
--- a/datasets.py
+++ b/datasets.py
@@ -150,15 +150,21 @@ def __init__(self, dataset, label, val=False):
             )
         else:
             self.transform_ground = transforms.Compose(
-                [transforms.Resize(cfg.pretrain.ground.img_size), transforms.ToTensor(), transforms.Normalize(
-                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
+                [
+                    transforms.Resize(cfg.pretrain.ground.img_size),
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
+                ]
             )
             self.transform_overhead = transforms.Compose(
                 [
                     transforms.Resize(cfg.pretrain.overhead.img_size),
                     transforms.ToTensor(),
                     transforms.Normalize(
-                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    ),
                 ]
             )
 
@@ -298,27 +304,27 @@ def __init__(self, dataset, label, ids):
         self.images = self.images[self.idx]
         self.ids = ids
         self.transform_ground = transforms.Compose(
-                [
-                    transforms.Resize(cfg.pretrain.ground.img_size),
-                    transforms.ToTensor(),
-                    transforms.Normalize(
-                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-                    ),
-                ]
-            )
+            [
+                transforms.Resize(cfg.pretrain.ground.img_size),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
         self.transform_overhead = transforms.Compose(
-                [
-                    transforms.Resize(cfg.pretrain.overhead.img_size),
-                    transforms.ToTensor(),
-                    transforms.Normalize(
-                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-                    ),
-                ]
-            )
-    
+            [
+                transforms.Resize(cfg.pretrain.overhead.img_size),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
     def __len__(self):
         return len(self.ids)
-    
+
     def __getitem__(self, idx):
         img_path = self.images[self.ids[idx]]["file_name"]
         lat = self.images[self.ids[idx]]["latitude"]
@@ -336,7 +342,9 @@ def __getitem__(self, idx):
         )
         img_ground = Image.open(os.path.join("data", img_path))
         img_ground = self.transform_ground(img_ground)
-        img_overhead = Image.open(f"data/val_overhead/images_sentinel/{self.ids[idx]}.jpeg")
+        img_overhead = Image.open(
+            f"data/val_overhead/images_sentinel/{self.ids[idx]}.jpeg"
+        )
         img_overhead = self.transform_overhead(img_overhead)
         if cfg.pretrain.train.mode == "no_metadata":
             return img_ground, img_overhead
diff --git a/retrieval_eval.py b/retrieval_eval.py
index 6a9de2d..f6f08cf 100644
--- a/retrieval_eval.py
+++ b/retrieval_eval.py
@@ -59,27 +59,27 @@ def retrieval_eval():
     )
 
     recall = 0
-    
+
     for batch in tqdm(test_loader):
-        if cfg.retrieval.mode=="full_metadata":
+        if cfg.retrieval.mode == "full_metadata":
             _, img_overhead, label, *_ = batch
-        else: 
+        else:
             _, img_overhead, label = batch
         z = 0
         running_val = 0
         running_label = 0
         for batch2 in tqdm(test_loader):
-            if cfg.retrieval.mode=="full_metadata":
+            if cfg.retrieval.mode == "full_metadata":
                 img_ground, _, label2, geoloc, date = batch2
                 ground_embeddings, overhead_embeddings = model.forward_features(
-                    img_ground.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
-                    )
+                    img_ground.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
+                )
             else:
                 img_ground, _, label2 = batch2
                 ground_embeddings, overhead_embeddings = model.forward_features(
-                    img_ground.cuda(), img_overhead.cuda()
-                    )
-        
+                    img_ground.cuda(), img_overhead.cuda()
+                )
+
             similarity = torch.einsum(
                 "ij,kj->ik", ground_embeddings, overhead_embeddings
             )
@@ -92,11 +92,11 @@ def retrieval_eval():
                 running_label = torch.cat((running_label, label2), dim=0)
         if cfg.retrieval.model_type == "CVEMAE":
             _, ind = torch.topk(running_val, 10, dim=0)
-        
+
         # Hierarchical Retrieval
         elif cfg.retrieval.model_type == "CVMMAE":
             _, ind = torch.topk(running_val, 50, dim=0)
-            if cfg.retrieval.mode=="full_metadata":
+            if cfg.retrieval.mode == "full_metadata":
                 img_ground, _, label2, geoloc, date = test_dataset[ind]
             else:
                 img_ground, _, label2 = test_dataset[ind]
@@ -105,15 +105,20 @@ def retrieval_eval():
             for i in range(50):
                 img_ground_rolled = torch.roll(img_ground, i, 0)
                 idx_rolled = torch.roll(idx, i, 0)
-                if cfg.retrieval.mode=="full_metadata":
+                if cfg.retrieval.mode == "full_metadata":
                     _, scores = model_filter.forward_features(
-                        img_ground_rolled.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
+                        img_ground_rolled.cuda(),
+                        img_overhead.cuda(),
+                        geoloc.cuda(),
+                        date.cuda(),
                     )
                 else:
                     _, scores = model_filter.forward_features(
                         img_ground_rolled.cuda(), img_overhead.cuda()
                     )
-                similarity[idx_rolled, torch.arange(50)] = scores.squeeze(0).detach().cpu()
+                similarity[idx_rolled, torch.arange(50)] = (
+                    scores.squeeze(0).detach().cpu()
+                )
             _, ind = torch.topk(similarity, 10, dim=0)
             running_label = label2
 
@@ -122,4 +127,3 @@ def retrieval_eval():
         [1 if label[i] in preds[:, i] else 0 for i in range(label.shape[0])]
     )
     print(f"Current Recall Score: {recall/len(test_dataset)}")
-