Polished #12
Annotations
5 errors and 1 warning
/home/runner/work/BirdSAT/BirdSAT/retrieval_eval.py#L57
shuffle=False,
num_workers=cfg.retrieval.num_workers,
)
recall = 0
-
+
for batch in tqdm(test_loader):
- if cfg.retrieval.mode=="full_metadata":
+ if cfg.retrieval.mode == "full_metadata":
_, img_overhead, label, *_ = batch
- else:
+ else:
_, img_overhead, label = batch
z = 0
running_val = 0
running_label = 0
for batch2 in tqdm(test_loader):
- if cfg.retrieval.mode=="full_metadata":
+ if cfg.retrieval.mode == "full_metadata":
img_ground, _, label2, geoloc, date = batch2
ground_embeddings, overhead_embeddings = model.forward_features(
- img_ground.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
- )
+ img_ground.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
+ )
else:
img_ground, _, label2 = batch2
ground_embeddings, overhead_embeddings = model.forward_features(
- img_ground.cuda(), img_overhead.cuda()
- )
-
+ img_ground.cuda(), img_overhead.cuda()
+ )
+
similarity = torch.einsum(
"ij,kj->ik", ground_embeddings, overhead_embeddings
)
if z == 0:
running_val = similarity.detach().cpu()
|
/home/runner/work/BirdSAT/BirdSAT/retrieval_eval.py#L90
else:
running_val = torch.cat((running_val, similarity.detach().cpu()), dim=0)
running_label = torch.cat((running_label, label2), dim=0)
if cfg.retrieval.model_type == "CVEMAE":
_, ind = torch.topk(running_val, 10, dim=0)
-
+
# Hierarchical Retrieval
elif cfg.retrieval.model_type == "CVMMAE":
_, ind = torch.topk(running_val, 50, dim=0)
- if cfg.retrieval.mode=="full_metadata":
+ if cfg.retrieval.mode == "full_metadata":
img_ground, _, label2, geoloc, date = test_dataset[ind]
else:
img_ground, _, label2 = test_dataset[ind]
similarity = torch.zeros((50, 50))
idx = torch.arange(50)
for i in range(50):
img_ground_rolled = torch.roll(img_ground, i, 0)
idx_rolled = torch.roll(idx, i, 0)
- if cfg.retrieval.mode=="full_metadata":
+ if cfg.retrieval.mode == "full_metadata":
_, scores = model_filter.forward_features(
- img_ground_rolled.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda()
+ img_ground_rolled.cuda(),
+ img_overhead.cuda(),
+ geoloc.cuda(),
+ date.cuda(),
)
else:
_, scores = model_filter.forward_features(
img_ground_rolled.cuda(), img_overhead.cuda()
)
- similarity[idx_rolled, torch.arange(50)] = scores.squeeze(0).detach().cpu()
+ similarity[idx_rolled, torch.arange(50)] = (
+ scores.squeeze(0).detach().cpu()
+ )
_, ind = torch.topk(similarity, 10, dim=0)
running_label = label2
preds = running_label[ind]
recall += sum(
[1 if label[i] in preds[:, i] else 0 for i in range(label.shape[0])]
)
print(f"Current Recall Score: {recall/len(test_dataset)}")
-
|
/home/runner/work/BirdSAT/BirdSAT/datasets.py#L148
),
]
)
else:
self.transform_ground = transforms.Compose(
- [transforms.Resize(cfg.pretrain.ground.img_size), transforms.ToTensor(), transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
+ [
+ transforms.Resize(cfg.pretrain.ground.img_size),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ),
+ ]
)
self.transform_overhead = transforms.Compose(
[
transforms.Resize(cfg.pretrain.overhead.img_size),
transforms.ToTensor(),
transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ),
]
)
def __len__(self):
return len(self.images)
|
/home/runner/work/BirdSAT/BirdSAT/datasets.py#L296
self.images = np.array(self.dataset["images"])
self.idx = np.array(label.iloc[:, 1]).astype(int)
self.images = self.images[self.idx]
self.ids = ids
self.transform_ground = transforms.Compose(
- [
- transforms.Resize(cfg.pretrain.ground.img_size),
- transforms.ToTensor(),
- transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
- ),
- ]
- )
+ [
+ transforms.Resize(cfg.pretrain.ground.img_size),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ),
+ ]
+ )
self.transform_overhead = transforms.Compose(
- [
- transforms.Resize(cfg.pretrain.overhead.img_size),
- transforms.ToTensor(),
- transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
- ),
- ]
- )
-
+ [
+ transforms.Resize(cfg.pretrain.overhead.img_size),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ),
+ ]
+ )
+
def __len__(self):
return len(self.ids)
-
+
def __getitem__(self, idx):
img_path = self.images[self.ids[idx]]["file_name"]
lat = self.images[self.ids[idx]]["latitude"]
lon = self.images[self.ids[idx]]["longitude"]
date = self.images[self.ids[idx]]["date"].split(" ")[0]
|
/home/runner/work/BirdSAT/BirdSAT/datasets.py#L334
np.cos(2 * np.pi * day / 31),
]
)
img_ground = Image.open(os.path.join("data", img_path))
img_ground = self.transform_ground(img_ground)
- img_overhead = Image.open(f"data/val_overhead/images_sentinel/{self.ids[idx]}.jpeg")
+ img_overhead = Image.open(
+ f"data/val_overhead/images_sentinel/{self.ids[idx]}.jpeg"
+ )
img_overhead = self.transform_overhead(img_overhead)
if cfg.pretrain.train.mode == "no_metadata":
return img_ground, img_overhead
else:
return (
|
Run linters
The following actions use node12, which is deprecated, and will be forced to run on node16: actions/checkout@v2, actions/setup-python@v1. For more info: https://github.blog/changelog/2023-06-13-github-actions-all-actions-will-run-on-node16-instead-of-node12-by-default/
|