-
Notifications
You must be signed in to change notification settings - Fork 0
/
RemoteSensingDataset.py
65 lines (45 loc) · 2.07 KB
/
RemoteSensingDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from other_imports import *
from configs import *
class BEN_Dataset(torch.utils.data.Dataset):
    """BigEarthNet-v1.0 patches rendered as 8-bit 3-channel images plus label vectors.

    Args:
        df: DataFrame whose first column holds the patch directory name and whose
            remaining columns hold the label values for that patch.
        transforms: Optional callable invoked as ``transforms(image=...)`` that
            returns a mapping with an ``"image"`` entry (presumably an
            albumentations-style pipeline -- confirm against the caller).
        root_dir: Directory containing one sub-directory per patch. Defaults to
            the path this class previously hard-coded.
    """

    # Raw band value mapped to 255; anything brighter is clipped to 255.
    _OPTICAL_MAX_VALUE = 2000.0
    # Band suffixes stacked, in this order, along the last axis
    # (presumably Sentinel-2 red/green/blue -- verify).
    _RGB_BANDS = ("04", "03", "02")

    def __init__(self, df, transforms=None, root_dir="../BigEarthNet-v1.0"):
        super().__init__()
        self.df = df
        self.transforms = transforms
        self.root_dir = root_dir
        # (patch_dir, label_array) per row. itertuples() is far faster than
        # iterrows() and unpacks columns by position, so column names do not
        # matter (integer keys on a Series are label-based in newer pandas).
        self.data = [
            (patch_dir, np.array(labels))
            for patch_dir, *labels in self.df.itertuples(index=False, name=None)
        ]

    def __len__(self):
        return len(self.df)

    def load_patches(self, patch_dir):
        """Load the three configured bands of ``patch_dir`` as a uint8 image.

        Each band is read from ``{root_dir}/{patch_dir}/{patch_dir}_B{band}.tif``
        as uint16. The bands are stacked on a new last axis, scaled so that
        ``_OPTICAL_MAX_VALUE`` maps to 255, and clipped to [0, 255]. The result
        is (H, W, 3) assuming each TIFF is a single-band 2-D raster.
        """
        bands = []
        for band in self._RGB_BANDS:
            path = f"{self.root_dir}/{patch_dir}/{patch_dir}_B{band}.tif"
            # The context manager closes the TIFF file handle. np.array copies
            # the pixels, so the data stays valid after the file is closed.
            with Image.open(path) as tif:
                bands.append(np.array(tif, dtype=np.uint16))
        stacked = np.stack(bands, axis=-1)
        scaled = stacked / self._OPTICAL_MAX_VALUE * 255.0
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def __getitem__(self, index):
        """Return ``{"image": ..., "label": ...}`` for the patch at ``index``.

        Without ``transforms`` the image is the uint8 array from
        ``load_patches``. With ``transforms``, the transformed ``"image"`` is
        divided by 255.0. The label is converted with ``torch.as_tensor``.
        """
        patch_dir, label = self.data[index]
        image = self.load_patches(patch_dir)
        if self.transforms:
            # NOTE(review): only the transformed path is rescaled to [0, 1];
            # confirm callers rely on the untransformed image staying uint8.
            image = self.transforms(image=image)["image"] / 255.0
        return {"image": image, "label": torch.as_tensor(label)}
class RSDataset(torch.utils.data.Dataset):
    """Image-file dataset driven by a DataFrame of image paths and label columns.

    Args:
        df: DataFrame whose first column holds an image file path and whose
            remaining columns hold the label values for that image.
        image_size: Side length in pixels; every image is resized to
            ``image_size x image_size`` with bicubic interpolation.
        transforms: Optional callable invoked as ``transforms(image=...)`` that
            returns a mapping with an ``"image"`` entry (presumably an
            albumentations-style pipeline -- confirm against the caller).
    """

    def __init__(self, df, image_size, transforms=None):
        super().__init__()
        self.df = df
        self.image_size = image_size
        self.transforms = transforms
        # (image_path, label Series) per row. .iloc selects by position, so the
        # DataFrame's column names do not matter (integer keys on a Series are
        # label-based in newer pandas).
        self.data = [(row.iloc[0], row.iloc[1:]) for _, row in self.df.iterrows()]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{"image": ..., "label": ...}`` for the sample at ``index``.

        The image is loaded as RGB and resized to ``image_size`` square. It is
        returned as a uint8 array when ``transforms`` is unset; otherwise the
        transformed ``"image"`` is divided by 255.0.
        """
        image_path, label = self.data[index]
        # iterrows() yields object-dtype rows when the path column is a string;
        # going through a plain list lets NumPy infer a numeric label dtype.
        label = torch.from_numpy(np.array(list(label.values)))
        # Close the file handle: convert() returns a new image, so the original
        # file would otherwise stay open until garbage collection.
        with Image.open(image_path) as img:
            image = np.array(img.convert('RGB'))
        # cv2.resize expects (width, height); both are image_size here.
        image = cv2.resize(
            image,
            (self.image_size, self.image_size),
            interpolation=cv2.INTER_CUBIC,
        )
        if self.transforms:
            # NOTE(review): only the transformed path is rescaled to [0, 1];
            # confirm callers rely on the untransformed image staying uint8.
            image = self.transforms(image=image)["image"] / 255.0
        return {"image": image, "label": label}