-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader_cloudcatalogue.py
executable file
·92 lines (68 loc) · 2.49 KB
/
dataloader_cloudcatalogue.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import pathlib

import numpy as np
import pandas as pd
import torch
import xarray as xr
def cloudcatalogue_find(root, dataset):
    """
    Retrieve the input/mask file paths of CloudCatalogue scenes with shadow labels.

    Only rows whose ``shadows_marked`` column equals 1 are kept. For every kept
    scene the input path is ``root / "subscenes" / "<scene>.npy"`` and the mask
    path is ``root / "masks" / "<scene>.npy"``.

    Args:
        root (pathlib.Path | str): The root directory of the dataset.
        dataset (pd.DataFrame): Scene table with at least the ``scene`` and
            ``shadows_marked`` columns. Any index is accepted (e.g. a filtered
            or shuffled DataFrame).

    Returns:
        tuple: ``(Xfiles, yfiles)``, two lists of ``pathlib.Path`` of equal
        length, aligned index-by-index and in the row order of ``dataset``.
    """
    root = pathlib.Path(root)
    Xroot = root / "subscenes"
    yroot = root / "masks"
    # Select rows with a boolean mask instead of looking them up by integer
    # label: label lookups only line up with positions for a default
    # RangeIndex and raise KeyError (or pick wrong rows) otherwise.
    scenes = dataset.loc[dataset["shadows_marked"] == 1, "scene"]
    Xfiles = [Xroot / f"{scene}.npy" for scene in scenes]
    yfiles = [yroot / f"{scene}.npy" for scene in scenes]
    return Xfiles, yfiles
class CloudDataset(torch.utils.data.Dataset):
    """
    PyTorch dataset serving padded CloudCatalogue input/mask pairs.

    Each sample is read from two ``.npy`` files: a channels-last input array
    and a channels-last mask array whose first three channels are combined
    into a single label map. Both are padded by one pixel on every spatial side.
    """

    def __init__(self, files):
        """
        CloudCatalogue dataset class for loading and preprocessing data.

        Args:
            files (tuple): ``(Xfiles, yfiles)`` — sequences of input and mask
                file paths aligned by index, e.g. the return value of
                ``cloudcatalogue_find``.

        Raises:
            ValueError: If the input and mask path sequences differ in length.
        """
        self.cloudcataloguefiles = files
        self.X = self.cloudcataloguefiles[0]
        self.y = self.cloudcataloguefiles[1]
        # Fail early: unpaired lists would silently serve mismatched samples
        # (or raise IndexError deep inside a DataLoader worker).
        if len(self.X) != len(self.y):
            raise ValueError(
                f"Got {len(self.X)} input files but {len(self.y)} mask files; "
                "they must be paired one-to-one."
            )

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.X)

    def __getitem__(self, idx):
        """
        Load, pad and convert one sample.

        Args:
            idx (int): The index of the sample.

        Returns:
            tuple: ``(X, y)`` where ``X`` is a float32 tensor of shape
            (C, H + 2, W + 2) and ``y`` is an int64 tensor of shape
            (H + 2, W + 2) with labels 0 (no cloud), 1 (cloud), 2 (shadow).
            Padded border pixels are 0 in both.
        """
        Xfile = self.X[idx]
        yfile = self.y[idx]

        # Channels-last (H, W, C) on disk -> channels-first (C, H, W), then pad
        # one pixel per spatial side. The segmentation model needs the padded
        # size (presumably to reach a size its down-sampling accepts — TODO confirm).
        X = np.load(Xfile).transpose(2, 0, 1)
        X = np.pad(X, ((0, 0), (1, 1), (1, 1)), "constant", constant_values=0)

        # Collapse the mask channels into a single label map:
        # 0: no cloud, 1: cloud, 2: shadow.
        # NOTE(review): assumes the three mask channels are mutually exclusive
        # (one-hot); a pixel flagged as both cloud and shadow would become 3.
        y = np.load(yfile).transpose(2, 0, 1)
        y = y[0] * 0 + y[1] * 1 + y[2] * 2
        y = np.pad(y, ((1, 1), (1, 1)), "constant")

        # From numpy to torch
        X = torch.from_numpy(X).type(torch.float)
        y = torch.from_numpy(y).type(torch.long)
        return X, y
# Build the CloudCatalogue dataset at import time.
# The dataset root defaults to the original author's local copy; set the
# CLOUDCATALOGUE_ROOT environment variable to use another location without
# editing this file.
root = pathlib.Path(
    os.environ.get(
        "CLOUDCATALOGUE_ROOT",
        "/media/csaybar/2F9A60C90A2CC0FB/IGARSS2023/cloudcatalogue",
    )
)
# Read the scene table (a CSV file); cloudcatalogue_find expects it to provide
# the "scene" and "shadows_marked" columns.
dataset = pd.read_csv(root / "classification_tags.csv")
# NOTE: despite its name, this is a CloudDataset instance, not a list of files.
cloudcatalogue_files = CloudDataset(cloudcatalogue_find(root, dataset))