Skip to content

Polished

Polished #12

Triggered via push October 28, 2023 05:42
Status Success
Total duration 19s
Artifacts

linting.yml

on: push
Run linters
10s
Run linters
Fit to window
Zoom out
Zoom in

Annotations

5 errors and 1 warning
/home/runner/work/BirdSAT/BirdSAT/retrieval_eval.py#L57
shuffle=False, num_workers=cfg.retrieval.num_workers, ) recall = 0 - + for batch in tqdm(test_loader): - if cfg.retrieval.mode=="full_metadata": + if cfg.retrieval.mode == "full_metadata": _, img_overhead, label, *_ = batch - else: + else: _, img_overhead, label = batch z = 0 running_val = 0 running_label = 0 for batch2 in tqdm(test_loader): - if cfg.retrieval.mode=="full_metadata": + if cfg.retrieval.mode == "full_metadata": img_ground, _, label2, geoloc, date = batch2 ground_embeddings, overhead_embeddings = model.forward_features( - img_ground.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda() - ) + img_ground.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda() + ) else: img_ground, _, label2 = batch2 ground_embeddings, overhead_embeddings = model.forward_features( - img_ground.cuda(), img_overhead.cuda() - ) - + img_ground.cuda(), img_overhead.cuda() + ) + similarity = torch.einsum( "ij,kj->ik", ground_embeddings, overhead_embeddings ) if z == 0: running_val = similarity.detach().cpu()
/home/runner/work/BirdSAT/BirdSAT/retrieval_eval.py#L90
else: running_val = torch.cat((running_val, similarity.detach().cpu()), dim=0) running_label = torch.cat((running_label, label2), dim=0) if cfg.retrieval.model_type == "CVEMAE": _, ind = torch.topk(running_val, 10, dim=0) - + # Hierarchical Retrieval elif cfg.retrieval.model_type == "CVMMAE": _, ind = torch.topk(running_val, 50, dim=0) - if cfg.retrieval.mode=="full_metadata": + if cfg.retrieval.mode == "full_metadata": img_ground, _, label2, geoloc, date = test_dataset[ind] else: img_ground, _, label2 = test_dataset[ind] similarity = torch.zeros((50, 50)) idx = torch.arange(50) for i in range(50): img_ground_rolled = torch.roll(img_ground, i, 0) idx_rolled = torch.roll(idx, i, 0) - if cfg.retrieval.mode=="full_metadata": + if cfg.retrieval.mode == "full_metadata": _, scores = model_filter.forward_features( - img_ground_rolled.cuda(), img_overhead.cuda(), geoloc.cuda(), date.cuda() + img_ground_rolled.cuda(), + img_overhead.cuda(), + geoloc.cuda(), + date.cuda(), ) else: _, scores = model_filter.forward_features( img_ground_rolled.cuda(), img_overhead.cuda() ) - similarity[idx_rolled, torch.arange(50)] = scores.squeeze(0).detach().cpu() + similarity[idx_rolled, torch.arange(50)] = ( + scores.squeeze(0).detach().cpu() + ) _, ind = torch.topk(similarity, 10, dim=0) running_label = label2 preds = running_label[ind] recall += sum( [1 if label[i] in preds[:, i] else 0 for i in range(label.shape[0])] ) print(f"Current Recall Score: {recall/len(test_dataset)}") -
/home/runner/work/BirdSAT/BirdSAT/datasets.py#L148
), ] ) else: self.transform_ground = transforms.Compose( - [transforms.Resize(cfg.pretrain.ground.img_size), transforms.ToTensor(), transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] + [ + transforms.Resize(cfg.pretrain.ground.img_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] ) self.transform_overhead = transforms.Compose( [ transforms.Resize(cfg.pretrain.overhead.img_size), transforms.ToTensor(), transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), ] ) def __len__(self): return len(self.images)
/home/runner/work/BirdSAT/BirdSAT/datasets.py#L296
self.images = np.array(self.dataset["images"]) self.idx = np.array(label.iloc[:, 1]).astype(int) self.images = self.images[self.idx] self.ids = ids self.transform_ground = transforms.Compose( - [ - transforms.Resize(cfg.pretrain.ground.img_size), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), - ] - ) + [ + transforms.Resize(cfg.pretrain.ground.img_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) self.transform_overhead = transforms.Compose( - [ - transforms.Resize(cfg.pretrain.overhead.img_size), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ), - ] - ) - + [ + transforms.Resize(cfg.pretrain.overhead.img_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + def __len__(self): return len(self.ids) - + def __getitem__(self, idx): img_path = self.images[self.ids[idx]]["file_name"] lat = self.images[self.ids[idx]]["latitude"] lon = self.images[self.ids[idx]]["longitude"] date = self.images[self.ids[idx]]["date"].split(" ")[0]
/home/runner/work/BirdSAT/BirdSAT/datasets.py#L334
np.cos(2 * np.pi * day / 31), ] ) img_ground = Image.open(os.path.join("data", img_path)) img_ground = self.transform_ground(img_ground) - img_overhead = Image.open(f"data/val_overhead/images_sentinel/{self.ids[idx]}.jpeg") + img_overhead = Image.open( + f"data/val_overhead/images_sentinel/{self.ids[idx]}.jpeg" + ) img_overhead = self.transform_overhead(img_overhead) if cfg.pretrain.train.mode == "no_metadata": return img_ground, img_overhead else: return (
Run linters
The following actions use node12, which is deprecated, and will be forced to run on node16: actions/checkout@v2, actions/setup-python@v1. For more info: https://github.blog/changelog/2023-06-13-github-actions-all-actions-will-run-on-node16-instead-of-node12-by-default/