-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
regression.py
361 lines (287 loc) · 10.2 KB
/
regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Sequence

import torch

from pytorch_lightning.metrics.functional.regression import (
    mae,
    mse,
    psnr,
    rmse,
    rmsle,
    ssim,
)
from pytorch_lightning.metrics.metric import Metric
class MSE(Metric):
    """
    Mean squared error between predictions and targets.

    ``forward`` only collects a per-step state (summed squared error and the
    number of observations) from the functional ``mse``; the final value is
    formed in ``compute`` as ``squared_error / n_observations``.

    Example:

        >>> pred = torch.tensor([0., 1, 2, 3])
        >>> target = torch.tensor([0., 1, 2, 2])
        >>> metric = MSE()
        >>> metric(pred, target)
        tensor(0.2500)

    """

    def __init__(
            self,
            reduction: str = 'elementwise_mean',
    ):
        """
        Args:
            reduction: a method to reduce metric score over labels.

                - ``'elementwise_mean'``: takes the mean (default)
                - ``'sum'``: takes the sum
                - ``'none'``: no reduction will be applied

                NOTE(review): the value is stored on the instance, but ``forward``
                does not pass it to ``mse`` and ``compute`` always divides the
                summed error by the observation count.
        """
        super().__init__(name='mse')
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Collect the squared-error state for one batch.

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            The state returned by ``mse(..., return_state=True)``; ``compute``
            reads its ``'squared_error'`` and ``'n_observations'`` entries.
        """
        step_state = mse(pred, target, return_state=True)
        return step_state

    # A staticmethod whose first parameter is named ``self``: presumably the base
    # ``Metric`` invokes it as a hook and passes the module explicitly — confirm.
    @staticmethod
    def compute(self, data: Any, output: Any):
        """Turn the (possibly synced) state into the mean squared error."""
        return output['squared_error'] / output['n_observations']
class RMSE(Metric):
    """
    Root mean squared error between predictions and targets.

    ``forward`` collects a summed squared-error state from the functional
    ``rmse``. ``compute`` then takes ``sqrt(squared_error / n_observations)``.

    Example:

        >>> pred = torch.tensor([0., 1, 2, 3])
        >>> target = torch.tensor([0., 1, 2, 2])
        >>> metric = RMSE()
        >>> metric(pred, target)
        tensor(0.5000)

    """

    def __init__(
            self,
            reduction: str = 'elementwise_mean',
    ):
        """
        Args:
            reduction: a method to reduce metric score over labels.

                - ``'elementwise_mean'``: takes the mean (default)
                - ``'sum'``: takes the sum
                - ``'none'``: no reduction will be applied

                NOTE(review): the value is stored on the instance, but ``forward``
                always calls ``rmse`` with ``reduction='none'`` and ``compute``
                always returns the root of the mean.
        """
        super().__init__(name='rmse')
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Collect the squared-error state for one batch.

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            The state returned by ``rmse(..., return_state=True)``; ``compute``
            reads its ``'squared_error'`` and ``'n_observations'`` entries.
        """
        step_state = rmse(pred, target, reduction='none', return_state=True)
        return step_state

    # A staticmethod whose first parameter is named ``self``: presumably the base
    # ``Metric`` invokes it as a hook and passes the module explicitly — confirm.
    @staticmethod
    def compute(self, data: Any, output: Any):
        """
        Turn the (possibly synced) state into the root mean squared error.

        The square root is applied only here, after the ddp sync of the summed
        state. The root of a sum is not the sum of the roots, so taking it per
        process first would be wrong.
        """
        mean_squared = output['squared_error'] / output['n_observations']
        return torch.sqrt(mean_squared)
class MAE(Metric):
    """
    Mean absolute error (L1 loss) between predictions and targets.

    ``forward`` only collects a per-step state (summed absolute error and the
    number of observations) from the functional ``mae``. The final value is
    formed in ``compute`` as ``absolute_error / n_observations``.

    Example:

        >>> pred = torch.tensor([0., 1, 2, 3])
        >>> target = torch.tensor([0., 1, 2, 2])
        >>> metric = MAE()
        >>> metric(pred, target)
        tensor(0.2500)

    """

    def __init__(
            self,
            reduction: str = 'elementwise_mean',
    ):
        """
        Args:
            reduction: a method to reduce metric score over labels.

                - ``'elementwise_mean'``: takes the mean (default)
                - ``'sum'``: takes the sum
                - ``'none'``: no reduction will be applied

                NOTE(review): the value is stored on the instance, but ``forward``
                does not pass it to ``mae`` and ``compute`` always divides the
                summed error by the observation count.
        """
        super().__init__(name='mae')
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Collect the absolute-error state for one batch.

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            The state returned by ``mae(..., return_state=True)``; ``compute``
            reads its ``'absolute_error'`` and ``'n_observations'`` entries.
        """
        step_state = mae(pred, target, return_state=True)
        return step_state

    # A staticmethod whose first parameter is named ``self``: presumably the base
    # ``Metric`` invokes it as a hook and passes the module explicitly — confirm.
    @staticmethod
    def compute(self, data: Any, output: Any):
        """Turn the (possibly synced) state into the mean absolute error."""
        return output['absolute_error'] / output['n_observations']
class RMSLE(Metric):
    """
    Computes the root mean squared log error.

    This is the RMSE between ``log(1 + pred)`` and ``log(1 + target)``.

    Example:

        >>> pred = torch.tensor([0., 1, 2, 3])
        >>> target = torch.tensor([0., 1, 2, 2])
        >>> metric = RMSLE()
        >>> metric(pred, target)
        tensor(0.1438)

    """

    def __init__(
            self,
            reduction: str = 'elementwise_mean',
    ):
        """
        Args:
            reduction: a method to reduce metric score over labels.

                - ``'elementwise_mean'``: takes the mean (default)
                - ``'sum'``: takes the sum
                - ``'none'``: no reduction will be applied

                It is forwarded to ``mse``. ``compute`` always returns the root
                of ``squared_error / n_observations``.
        """
        super().__init__(name='rmsle')
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Collect the squared-log-error state for one batch.

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            The state returned by ``mse(..., return_state=True)`` on the
            log-transformed inputs. ``compute`` reads its ``'squared_error'``
            and ``'n_observations'`` entries.
        """
        # log1p(x) is the numerically accurate form of log(x + 1): it avoids the
        # rounding loss of forming ``x + 1`` first when |x| is small.
        return mse(torch.log1p(pred), torch.log1p(target),
                   self.reduction, return_state=True)

    # A staticmethod whose first parameter is named ``self``: presumably the base
    # ``Metric`` invokes it as a hook and passes the module explicitly — confirm.
    @staticmethod
    def compute(self, data: Any, output: Any):
        """
        Turn the (possibly synced) state into the root mean squared log error.

        The square root must be taken only after the ddp sync of the summed
        state, because the root of a sum is not the sum of the roots.
        """
        sse, n = output['squared_error'], output['n_observations']
        return torch.sqrt(sse / n)
class PSNR(Metric):
    """
    Computes the peak signal-to-noise ratio.

    ``forward`` collects a per-step state (summed squared error, observation
    count and data range). ``aggregate`` combines several such states.
    ``compute`` returns ``10 * log_base(data_range ** 2 / mse)``.

    Example:

        >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
        >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
        >>> metric = PSNR()
        >>> metric(pred, target)
        tensor(2.5527)

    """

    def __init__(
            self,
            data_range: Optional[float] = None,
            base: int = 10,
            reduction: str = 'elementwise_mean'
    ):
        """
        Args:
            data_range: the range of the data. If None, it is determined from the data (max - min)
            base: a base of a logarithm to use (default: 10). Must be positive and
                different from 1.
            reduction: a method to reduce metric score over labels.

                - ``'elementwise_mean'``: takes the mean (default)
                - ``'sum'``: takes the sum
                - ``'none'``: no reduction will be applied

        Raises:
            ValueError: if ``base`` is not positive or equals 1. In that case
                ``log(base)`` is undefined or zero, and ``compute`` would
                otherwise silently yield inf/nan.
        """
        if base <= 0 or base == 1:
            raise ValueError(f"`base` must be positive and different from 1, got {base}")
        super().__init__(name='psnr')
        self.data_range = data_range
        self.base = float(base)
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Collect the PSNR state for one batch.

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            The state returned by ``psnr(..., return_state=True)``. ``compute``
            reads its ``'sum_squared_error'``, ``'n_obs'`` and ``'data_range'``
            entries.
        """
        return psnr(pred, target, self.data_range, self.base, self.reduction, return_state=True)

    def aggregate(self, *tensors: Any) -> Dict[str, torch.Tensor]:
        """
        Combine per-step states. This needs special handling because the data
        range has to be synced correctly.

        ``'data_range'`` is reduced with ``max`` (the widest observed range wins).
        Every other entry is summed along a new leading dimension.

        Args:
            tensors: either several state dicts, or a single mapping whose values
                are sequences of per-step tensors.

        Returns:
            A dict with the same keys, starting with ``'data_range'``.
        """
        if len(tensors) == 1:
            # Single argument: a mapping of state key -> sequence of per-step tensors.
            states = tensors[0]
            output = {'data_range': torch.stack(list(states['data_range'])).max()}
            output.update({k: torch.stack(list(states[k])).sum(0) for k in states.keys() if k != 'data_range'})
            return output

        # Several arguments: one state dict per step, all sharing the first one's keys.
        output = {'data_range': torch.stack([state['data_range'] for state in tensors]).max()}
        output.update({
            k: torch.stack([state[k] for state in tensors]).sum(0)
            for k in tensors[0].keys() if k != 'data_range'
        })
        return output

    # A staticmethod whose first parameter is named ``self``: presumably the base
    # ``Metric`` invokes it as a hook and passes the module explicitly — confirm.
    @staticmethod
    def compute(self, data: Any, output: Any):
        """
        Compute the final value from the synced data_range, the sum of squared
        errors and the number of samples.

        Args:
            data: input to forward method
            output: output from the `aggregate` hook

        Returns:
            final metric value
        """
        sse, n, data_range = output['sum_squared_error'], output['n_obs'], output['data_range']
        # 10 * log_b(range^2 / mse) == (2 * ln(range) - ln(mse)) * 10 / ln(b)
        psnr_base_e = 2 * torch.log(data_range) - torch.log(sse / n)
        # Plain float scale factor; ``base`` was validated in ``__init__``.
        score = psnr_base_e * (10 / math.log(self.base))
        return score
class SSIM(Metric):
    """
    Computes the Structural Similarity Index Measure (SSIM) between two images.

    Unlike the other metrics in this module, ``forward`` does not request
    ``return_state`` from the functional ``ssim``. It returns that function's
    score directly.

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 0.75
        >>> metric = SSIM()
        >>> metric(pred, target)
        tensor(0.9219)

    """

    # NOTE(review): the example above draws unseeded random input, so the printed
    # value is only stable if the doctest setup seeds torch's RNG — confirm.

    def __init__(
            self,
            kernel_size: Sequence[int] = (11, 11),
            sigma: Sequence[float] = (1.5, 1.5),
            reduction: str = "elementwise_mean",
            data_range: Optional[float] = None,
            k1: float = 0.01,
            k2: float = 0.03
    ):
        """
        Args:
            kernel_size: Size of the gaussian kernel (default: (11, 11))
            sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
            reduction: a method to reduce metric score over labels.

                - ``'elementwise_mean'``: takes the mean (default)
                - ``'sum'``: takes the sum
                - ``'none'``: no reduction will be applied

            data_range: Range of the image. If ``None``, it is determined from the image (max - min)
            k1: Parameter of SSIM. Default: 0.01
            k2: Parameter of SSIM. Default: 0.03
        """
        super().__init__(name="ssim")
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.reduction = reduction
        self.data_range = data_range
        self.k1 = k1
        self.k2 = k2

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation.

        Args:
            pred: Estimated image
            target: Ground truth image

        Return:
            A Tensor with SSIM score.
        """
        # Arguments are passed positionally in the functional ``ssim``'s order:
        # kernel_size, sigma, reduction, data_range, k1, k2.
        return ssim(pred, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2)