Skip to content

Commit

Permalink
Polished
Browse files Browse the repository at this point in the history
  • Loading branch information
Vishu26 committed Oct 28, 2023
1 parent af60c5d commit b5867fd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 37 deletions.
52 changes: 30 additions & 22 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,21 @@ def __init__(self, dataset, label, val=False):
)
else:
self.transform_ground = transforms.Compose(
[transforms.Resize(cfg.pretrain.ground.img_size), transforms.ToTensor(), transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
[
transforms.Resize(cfg.pretrain.ground.img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
self.transform_overhead = transforms.Compose(
[
transforms.Resize(cfg.pretrain.overhead.img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)

Expand Down Expand Up @@ -298,27 +304,27 @@ def __init__(self, dataset, label, ids):
self.images = self.images[self.idx]
self.ids = ids
self.transform_ground = transforms.Compose(
[
transforms.Resize(cfg.pretrain.ground.img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
[
transforms.Resize(cfg.pretrain.ground.img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
self.transform_overhead = transforms.Compose(
[
transforms.Resize(cfg.pretrain.overhead.img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)
[
transforms.Resize(cfg.pretrain.overhead.img_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
]
)

def __len__(self):
return len(self.ids)

def __getitem__(self, idx):
img_path = self.images[self.ids[idx]]["file_name"]
lat = self.images[self.ids[idx]]["latitude"]
Expand All @@ -336,7 +342,9 @@ def __getitem__(self, idx):
)
img_ground = Image.open(os.path.join("data", img_path))
img_ground = self.transform_ground(img_ground)
img_overhead = Image.open(f"data/val_overhead/images_sentinel/{self.ids[idx]}.jpeg")
img_overhead = Image.open(
f"data/val_overhead/images_sentinel/{self.ids[idx]}.jpeg"
)
img_overhead = self.transform_overhead(img_overhead)
if cfg.pretrain.train.mode == "no_metadata":
return img_ground, img_overhead
Expand Down
34 changes: 19 additions & 15 deletions retrieval_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,27 @@ def retrieval_eval():
)

recall = 0

for batch in tqdm(test_loader):
if cfg.retrieval.mode=="full_metadata":
if cfg.retrieval.mode == "full_metadata":
_, img_overhead, label, *_ = batch
else:
else:
_, img_overhead, label = batch
z = 0
running_val = 0
running_label = 0
for batch2 in tqdm(test_loader):
if cfg.retrieval.mode=="full_metadata":
if cfg.retrieval.mode == "full_metadata":
img_ground, _, label2, geoloc, date = batch2
ground_embeddings, overhead_embeddings = model.forward_features(
img_ground.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
)
img_ground.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
)
else:
img_ground, _, label2 = batch2
ground_embeddings, overhead_embeddings = model.forward_features(
img_ground.cuda(), img_overhead.cuda()
)
img_ground.cuda(), img_overhead.cuda()
)

similarity = torch.einsum(
"ij,kj->ik", ground_embeddings, overhead_embeddings
)
Expand All @@ -92,11 +92,11 @@ def retrieval_eval():
running_label = torch.cat((running_label, label2), dim=0)
if cfg.retrieval.model_type == "CVEMAE":
_, ind = torch.topk(running_val, 10, dim=0)

# Hierarchical Retrieval
elif cfg.retrieval.model_type == "CVMMAE":
_, ind = torch.topk(running_val, 50, dim=0)
if cfg.retrieval.mode=="full_metadata":
if cfg.retrieval.mode == "full_metadata":
img_ground, _, label2, geoloc, date = test_dataset[ind]
else:
img_ground, _, label2 = test_dataset[ind]
Expand All @@ -105,15 +105,20 @@ def retrieval_eval():
for i in range(50):
img_ground_rolled = torch.roll(img_ground, i, 0)
idx_rolled = torch.roll(idx, i, 0)
if cfg.retrieval.mode=="full_metadata":
if cfg.retrieval.mode == "full_metadata":
_, scores = model_filter.forward_features(
img_ground_rolled.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
img_ground_rolled.cuda(),
img_overhead.cuda(),
geoloc.cuda(),
date.cuda(),
)
else:
_, scores = model_filter.forward_features(
img_ground_rolled.cuda(), img_overhead.cuda()
)
similarity[idx_rolled, torch.arange(50)] = scores.squeeze(0).detach().cpu()
similarity[idx_rolled, torch.arange(50)] = (
scores.squeeze(0).detach().cpu()
)
_, ind = torch.topk(similarity, 10, dim=0)
running_label = label2

Expand All @@ -122,4 +127,3 @@ def retrieval_eval():
[1 if label[i] in preds[:, i] else 0 for i in range(label.shape[0])]
)
print(f"Current Recall Score: {recall/len(test_dataset)}")

0 comments on commit b5867fd

Please sign in to comment.