-
Notifications
You must be signed in to change notification settings - Fork 2
/
transforms.py
134 lines (104 loc) · 4.47 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import random
import numpy as np
import torch
class Identity(object):
    """Pass-through transform.

    Useful as a placeholder where a transform is required but no
    processing should happen: the very same ``sample`` object is returned.
    """

    def __call__(self, sample):
        result = sample
        return result
class RandomSamplePixels(object):
    """Randomly draw ``num_pixels`` pixels from ``sample['pixels']``.

    ``sample['pixels']`` is unpacked as ``(T, C, S)``; the last axis (S) is
    sampled or padded so that it has exactly ``num_pixels`` entries:

    * ``S > num_pixels``: ``num_pixels`` distinct indices are drawn without
      replacement (in random, not sorted, order).
    * ``S < num_pixels``: all ``S`` pixels are kept and the remaining slots
      are filled with copies of the first pixel (index 0).
    * ``S == num_pixels``: the array is passed through unchanged.

    The output keeps the dtype of the input in every branch.
    ``sample['valid_pixels']`` is set to a float32 array of shape
    ``(T, num_pixels)`` holding 1 for real pixels and 0 for padded copies.

    Args:
        num_pixels (int): Number of pixels to sample.
    """

    def __init__(self, num_pixels):
        self.num_pixels = num_pixels

    def __call__(self, sample):
        pixels = sample['pixels']
        T, C, S = pixels.shape
        n = self.num_pixels
        if S > n:
            indices = random.sample(range(S), n)
            x = pixels[:, :, indices]
            valid_pixels = np.ones(n)
        elif S < n:
            # Allocate with the input dtype: a bare np.zeros would upcast to
            # float64 only in this branch, making the output dtype depend on S.
            x = np.zeros((T, C, n), dtype=pixels.dtype)
            x[..., :S] = pixels
            # Fill the padding slots with copies of pixel 0 (broadcast along
            # the last axis). When S == 0, pixel 0 is itself zero padding.
            x[..., S:] = x[..., :1]
            valid_pixels = np.array([1] * S + [0] * (n - S))
        else:
            x = pixels
            valid_pixels = np.ones(n)
        # Repeat the per-pixel validity mask across the time axis -> (T, n).
        valid_pixels = np.repeat(valid_pixels[np.newaxis].astype(np.float32), x.shape[0], axis=0)
        sample['pixels'] = x
        sample['valid_pixels'] = valid_pixels
        return sample
class RandomSampleTimeSteps(object):
    """Randomly draw ``seq_length`` time steps to fix the time dimension.

    Axis 0 of ``sample['pixels']`` is treated as time. When it is longer than
    ``seq_length``, ``seq_length`` distinct indices are drawn without
    replacement, sorted (so chronological order is preserved), and applied to
    axis 0 of ``'pixels'``, ``'positions'``, ``'valid_pixels'`` and ``'gdd'``.

    Args:
        seq_length (int): Number of time steps to sample. If -1, do nothing.

    Raises:
        NotImplementedError: if the sample has fewer than ``seq_length``
            time steps (padding short sequences is not supported).
    """

    # Entries whose leading axis is time and must be subsampled together.
    _TIME_KEYS = ('pixels', 'positions', 'valid_pixels', 'gdd')

    def __init__(self, seq_length):
        self.seq_length = seq_length

    def __call__(self, sample):
        if self.seq_length == -1:
            return sample
        t = sample['pixels'].shape[0]
        if t == self.seq_length:
            return sample
        if t < self.seq_length:
            raise NotImplementedError(
                f"Sample has {t} time steps but seq_length={self.seq_length}; "
                "padding short sequences is not supported."
            )
        # Read every entry before writing any, so a missing key fails without
        # leaving the sample half-modified.
        values = {key: sample[key] for key in self._TIME_KEYS}
        indices = sorted(random.sample(range(t), self.seq_length))
        for key, value in values.items():
            sample[key] = value[indices]
        return sample
class ShiftAug(object):
    """Randomly shift date positions by a uniform integer offset.

    With probability ``p``, an integer drawn uniformly from the inclusive
    range ``[-max_shift, max_shift]`` is added to ``sample['positions']``;
    otherwise the sample is returned untouched.

    Source: https://github.com/jnyborg/timematch/blob/main/transforms.py
    Paper: https://arxiv.org/abs/2111.02682

    Args:
        max_shift (int): Maximum possible temporal shift
        p (float): Probability to apply
    """

    def __init__(self, max_shift=60, p=1.0):
        self.max_shift = max_shift
        self.p = p

    def __call__(self, sample):
        draw = random.random()
        if draw >= self.p:
            return sample
        bound = self.max_shift
        offset = random.randint(-bound, bound)
        sample['positions'] = sample['positions'] + offset
        return sample
class Normalize(object):
    """Normalize by rescaling pixels to [0, 1], plus optional extra features.

    Pixels are clipped to ``[0, max_pixel_value]`` and divided by
    ``max_pixel_value``. If the sample has an ``'extra'`` entry, it is divided
    element-wise by ``self.max_extra_values`` (approximate maxima for
    perimeter, area, perimeter/area ratio and cover ratio, in that order).
    Extra values are not clipped, so they may fall outside [0, 1].
    Both outputs are float32.

    Args:
        max_pixel_value (int): Max value of pixels to move pixels to [0, 1]
    """

    def __init__(self, max_pixel_value=65535):
        self.max_pixel_value = max_pixel_value
        # Approximate max values for the extra parcel features
        # (the `_m` suffix suggests metres; NOTE(review): confirm units).
        max_parcel_box_m = 10000
        max_perimeter = max_parcel_box_m * 4
        max_area = max_parcel_box_m ** 2
        max_perimeter_area_ratio = max_perimeter
        max_cover_ratio = 1.0
        self.max_extra_values = np.array([max_perimeter, max_area, max_perimeter_area_ratio, max_cover_ratio])

    def __call__(self, sample):
        sample['pixels'] = np.clip(sample['pixels'], 0, self.max_pixel_value).astype(np.float32) / self.max_pixel_value
        if 'extra' in sample:
            # max_extra_values is float64, which would promote the quotient back
            # to float64; cast so 'extra' is float32 like 'pixels'.
            sample['extra'] = (sample['extra'].astype(np.float32) / self.max_extra_values).astype(np.float32)
        return sample
class ToTensor(object):
    """Convert the numpy arrays of a sample dict into torch tensors.

    * ``'pixels'``, ``'valid_pixels'`` and (if present) ``'extra'`` become
      float32 tensors.
    * ``'positions'`` and ``'gdd'`` become int64 (``torch.long``) tensors;
      the integer cast truncates any fractional part.
    * An integer ``'label'`` (Python ``int`` or numpy integer scalar) becomes
      a 0-d int64 tensor; any other label is left as is.
    """

    def __call__(self, sample):
        sample['pixels'] = torch.from_numpy(sample['pixels'].astype(np.float32))
        sample['valid_pixels'] = torch.from_numpy(sample['valid_pixels'].astype(np.float32))
        # np.long was removed in NumPy 1.24 (AttributeError); np.int64 is
        # available everywhere and matches torch.long on every platform.
        sample['positions'] = torch.from_numpy(sample['positions'].astype(np.int64))
        sample['gdd'] = torch.from_numpy(sample['gdd'].astype(np.int64))
        if 'extra' in sample:
            sample['extra'] = torch.from_numpy(sample['extra'].astype(np.float32))
        # numpy integer scalars are not `int` instances, so accept them too.
        if isinstance(sample['label'], (int, np.integer)):
            sample['label'] = torch.tensor(sample['label']).long()
        return sample