-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_indices.py
41 lines (33 loc) · 1.55 KB
/
get_indices.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os

import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import timm
import torch

from dataloader_cloudcatalogue import cloudcatalogue_files
from dataloader_cloudsen12 import testing_data, training_data, validation_data
from dataloader_kappaset import kappaset_files
from utils import cloudcatalogue_table, cloudsen12_table, kappaset_table
# --- Model loading -----------------------------------------------------------
# Both networks are built WITHOUT ImageNet weights on purpose. Every parameter
# is immediately overwritten from the local checkpoint (load_state_dict runs
# with the default strict=True, so all keys must be present and match).
# Fetching pretrained weights first would only add a network download and
# make the script fail offline.


def _load_frozen(model, weights_path):
    """Load a checkpoint into ``model``, put it in eval mode and move it to the GPU.

    The checkpoint is mapped to CPU first, so loading does not depend on the
    device the weights were saved from. The whole model is then moved to
    CUDA in one step.

    Args:
        model: A freshly constructed ``torch.nn.Module`` whose architecture
            matches the checkpoint.
        weights_path: Path to a file produced by ``torch.save(state_dict)``.

    Returns:
        The same ``model`` object, loaded, in eval mode, on CUDA.
    """
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    model.cuda()
    return model


# Hardness index model
HImodel = _load_frozen(
    timm.create_model("resnet10t", pretrained=False, num_classes=1, in_chans=13),
    "weights/resnet10t.pt",
)

# Trustworthiness index model
TImodel = _load_frozen(
    smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,
        in_channels=13,
        classes=4,
    ),
    "weights/UNetMobV2.pt",
)

# --- Index tables ------------------------------------------------------------
# pandas' to_csv does not create missing directories; make sure the output
# folder exists so a fresh checkout does not crash after the (slow) inference.
os.makedirs("results", exist_ok=True)

# NOTE(review): the *_table helpers live in utils (not visible here). If they
# do not already disable autograd internally, wrapping these calls in
# torch.no_grad() would reduce GPU memory use. Confirm before changing.

# CloudSEN12 table with TI and HI indices: train, validation and test splits
# are stacked row-wise into a single CSV.
train_cloudsen12_db = cloudsen12_table(training_data, HImodel, TImodel)
val_cloudsen12_db = cloudsen12_table(validation_data, HImodel, TImodel)
test_cloudsen12_db = cloudsen12_table(testing_data, HImodel, TImodel)
cloudsen12_db = pd.concat(
    [train_cloudsen12_db, val_cloudsen12_db, test_cloudsen12_db], axis=0
)
cloudsen12_db.to_csv("results/cloudsen12_indices.csv", index=False)

# Kappaset table with TI and HI indices. The "train_" prefix is kept for
# compatibility; the table is built from kappaset_files exactly as passed in.
train_kappaset_db = kappaset_table(kappaset_files, HImodel, TImodel)
train_kappaset_db.to_csv("results/kappaset_indices.csv", index=False)

# CloudCatalogue table with TI and HI indices (same "train_" naming caveat).
train_cloudcatalogue_db = cloudcatalogue_table(cloudcatalogue_files, HImodel, TImodel)
train_cloudcatalogue_db.to_csv("results/cloudcatalogue_indices.csv", index=False)