-
Notifications
You must be signed in to change notification settings - Fork 2
/
predict_test_plot_comparison.py
298 lines (260 loc) · 13.1 KB
/
predict_test_plot_comparison.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.transforms import Bbox
# import and set up the typeguard
from typeguard.importhook import install_import_hook
# comment these out when deploying:
install_import_hook('src.nn')
install_import_hook('src.scoring_rules')
install_import_hook('src.utils')
install_import_hook('src.parsers')
install_import_hook('src.calibration')
install_import_hook('src.weatherbench_utils')
install_import_hook('src.unet_utils')
from src.nn import ConditionalGenerativeModel, createGenerativeFCNN, InputTargetDataset, \
UNet2D, DiscardWindowSizeDim, get_predictions_and_target, createGenerativeGRUNN
from src.utils import load_net, def_loader_kwargs
from src.parsers import parser_predict, nonlinearities_dict, setup, default_model_folder, default_root_folder
from src.weatherbench_utils import load_weatherbench_data
# --- parser ---
# Parse the command-line options and unpack them into module-level names used
# throughout the script. The flag set is defined in src.parsers.parser_predict.
parser = parser_predict()
args = parser.parse_args()
model = args.model  # "lorenz" or "lorenz96" below; "WeatherBench" selects the UNet path
# method = args.method
# scoring_rule = args.scoring_rule
kernel = args.kernel
patched = args.patched
base_measure = args.base_measure
root_folder = args.root_folder
model_folder = args.model_folder
datasets_folder = args.datasets_folder
weatherbench_data_folder = args.weatherbench_data_folder
weatherbench_small = args.weatherbench_small
unet_noise_method = args.unet_noise_method
unet_large = args.unet_large
# lr = args.lr
# lr_c = args.lr_c
batch_size = args.batch_size
no_early_stop = args.no_early_stop
critic_steps_every_generator_step = args.critic_steps_every_generator_step
save_plots = not args.no_save_plots  # flag is "no_save_plots", so invert
cuda = args.cuda
load_all_data_GPU = args.load_all_data_GPU
training_ensemble_size = args.training_ensemble_size
prediction_ensemble_size = args.prediction_ensemble_size
nonlinearity = args.nonlinearity
data_size = args.data_size
auxiliary_var_size = args.auxiliary_var_size
seed = args.seed
plot_start_timestep = args.plot_start_timestep
plot_end_timestep = args.plot_end_timestep
gamma = args.gamma_kernel_score
gamma_patched = args.gamma_kernel_score_patched
patch_size = args.patch_size
no_RNN = args.no_RNN
hidden_size_rnn = args.hidden_size_rnn
save_pdf = True
# NOTE(review): the two assignments below unconditionally override the CLI
# values read above, so --plot_start_timestep / --plot_end_timestep are
# silently ignored — confirm this hard-coded plot window is intended.
plot_start_timestep = 0
plot_end_timestep = 30
# patched scoring rules are only relevant for the lorenz96 model
compute_patched = model in ["lorenz96", ]
# Per-model training configurations for the three compared approaches.
# Each entry: (method, scoring_rule, lr, lr_c, hidden_size_rnn,
#              critic_steps_every_generator_step).
# NOTE(review): "WeatherBench" has no entry here, so this script currently
# raises NotImplementedError before any WeatherBench-specific code runs —
# confirm whether a WeatherBench config is missing.
_MODEL_CONFIGS = {
    "lorenz": (
        [
            ("SR", "Energy", 0.01, None, 8, 1),
            ("GAN", None, 0.0001, 0.001, 8, 1),
            ("WGAN_GP", None, 0.0003, 0.03, 8, 5),
        ],
        ["Energy", "GAN", "WGAN-GP"],  # legend labels, in the same order
    ),
    "lorenz96": (
        [
            ("SR", "EnergyKernel", 0.001, None, 32, 1),
            ("GAN", None, 0.0001, 0.001, 64, 1),
            ("WGAN_GP", None, 0.0001, 0.01, 64, 5),
        ],
        ["Energy-Kernel", "GAN", "WGAN-GP"],
    ),
}
if model not in _MODEL_CONFIGS:
    raise NotImplementedError
_settings, methods_list = _MODEL_CONFIGS[model]
# Unpack into the numbered names the rest of the script expects.
(method1, scoring_rule1, lr1, lr_c1, hidden_size_rnn_1, critic_steps_every_generator_step1), \
    (method2, scoring_rule2, lr2, lr_c2, hidden_size_rnn_2, critic_steps_every_generator_step2), \
    (method3, scoring_rule3, lr3, lr_c3, hidden_size_rnn_3, critic_steps_every_generator_step3) = _settings
model_is_weatherbench = model == "WeatherBench"
# Architecture selection: UNet for WeatherBench, otherwise FCNN (--no_RNN) or GRU-RNN.
nn_model = "unet" if model_is_weatherbench else ("fcnn" if no_RNN else "rnn")
# Resolve, for each of the three methods, the folder holding the trained net and
# the derived quantities (name postfix, unet depths, GAN flag, RNN hidden size).
# NOTE(review): each setup() call also rebinds datasets_folder, data_size,
# auxiliary_var_size and patch_size, so calls 2 and 3 receive values produced by
# the previous call — presumably identical across calls; verify in src.parsers.setup.
datasets_folder, nets_folder1, data_size, auxiliary_var_size, name_postfix1, unet_depths, patch_size, method_is_gan1, hidden_size_rnn_1 = \
    setup(model, root_folder, model_folder, args.datasets_folder, data_size, method1, scoring_rule1, kernel, patched,
          patch_size, training_ensemble_size, auxiliary_var_size, critic_steps_every_generator_step1, base_measure, lr1,
          lr_c1, batch_size, no_early_stop, unet_noise_method, unet_large, nn_model, hidden_size_rnn_1)
datasets_folder, nets_folder2, data_size, auxiliary_var_size, name_postfix2, unet_depths, patch_size, method_is_gan2, hidden_size_rnn_2 = \
    setup(model, root_folder, model_folder, args.datasets_folder, data_size, method2, scoring_rule2, kernel, patched,
          patch_size, training_ensemble_size, auxiliary_var_size, critic_steps_every_generator_step2, base_measure, lr2,
          lr_c2, batch_size, no_early_stop, unet_noise_method, unet_large, nn_model, hidden_size_rnn_2)
datasets_folder, nets_folder3, data_size, auxiliary_var_size, name_postfix3, unet_depths, patch_size, method_is_gan3, hidden_size_rnn_3 = \
    setup(model, root_folder, model_folder, args.datasets_folder, data_size, method3, scoring_rule3, kernel, patched,
          patch_size, training_ensemble_size, auxiliary_var_size, critic_steps_every_generator_step3, base_measure, lr3,
          lr_c3, batch_size, no_early_stop, unet_noise_method, unet_large, nn_model, hidden_size_rnn_3)
# --- data handling ---
# Load the validation and test data, then build the corresponding DataLoaders.
if model_is_weatherbench:
    print("Load weatherbench dataset...")
    dataset_train, dataset_val, dataset_test = load_weatherbench_data(weatherbench_data_folder, cuda, load_all_data_GPU,
                                                                      return_test=True,
                                                                      weatherbench_small=weatherbench_small)
    print("Loaded")
else:
    # Tensors were saved by the training pipeline as .pty files.
    input_data_test = torch.load(datasets_folder + "test_x.pty")
    target_data_test = torch.load(datasets_folder + "test_y.pty")
    input_data_val = torch.load(datasets_folder + "val_x.pty")
    target_data_val = torch.load(datasets_folder + "val_y.pty")
    window_size = input_data_test.shape[1]
    # keep the data on GPU only when both cuda and load_all_data_GPU are set
    dataset_device = "cuda" if cuda and load_all_data_GPU else "cpu"
    # create the test loaders; these are unused for the moment.
    dataset_val = InputTargetDataset(input_data_val, target_data_val, dataset_device)
    dataset_test = InputTargetDataset(input_data_test, target_data_test, dataset_device)
print("Validation set size:", len(dataset_val))
print("Test set size:", len(dataset_test))
loader_kwargs = def_loader_kwargs(cuda, load_all_data_GPU)
# loader_kwargs.update(loader_kwargs_2)  # if you want to add other loader arguments
data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, **loader_kwargs)
data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, **loader_kwargs)
# --- networks ---
# Rebuild each of the three generative nets with the same architecture used at
# training time, then load its saved weights from the corresponding folder.
nets_list = []
nets_folder_list = [nets_folder1, nets_folder2, nets_folder3]
name_postfix_list = [name_postfix1, name_postfix2, name_postfix3]
gru_hidden_size_list = [hidden_size_rnn_1, hidden_size_rnn_2, hidden_size_rnn_3]
for i in range(3):
    # wrap_net: whether the inner net is wrapped in ConditionalGenerativeModel
    # (presumably the wrapper supplies the auxiliary noise input — see src.nn);
    # the UNet "dropout" noise method is the one case that skips the wrapper.
    wrap_net = True
    # create generative net:
    if nn_model == "fcnn":
        input_size = window_size * data_size + auxiliary_var_size
        output_size = data_size
        # hidden-layer widths scale with the input/output sizes
        hidden_sizes_list = [int(input_size * 1.5), int(input_size * 3), int(input_size * 3),
                             int(input_size * 0.75 + output_size * 3), int(output_size * 5)]
        inner_net = createGenerativeFCNN(input_size=input_size, output_size=output_size, hidden_sizes=hidden_sizes_list,
                                         nonlinearity=nonlinearities_dict[nonlinearity])()
    elif nn_model == "rnn":
        output_size = data_size
        gru_layers = 1
        # hidden size differs per method (gru_hidden_size_list), set by setup()
        inner_net = createGenerativeGRUNN(data_size=data_size, gru_hidden_size=gru_hidden_size_list[i],
                                          noise_size=auxiliary_var_size,
                                          output_size=output_size, hidden_sizes=None, gru_layers=gru_layers,
                                          nonlinearity=nonlinearities_dict[nonlinearity])()
    elif nn_model == "unet":
        # select the noise method here:
        inner_net = UNet2D(in_channels=data_size[0], out_channels=1, noise_method=unet_noise_method,
                           number_generations_per_forward_call=prediction_ensemble_size, conv_depths=unet_depths)
        if unet_noise_method in ["sum", "concat"]:
            # here we overwrite the auxiliary_var_size above, as there is a precise constraint
            downsampling_factor, n_channels = inner_net.calculate_downsampling_factor()
            if weatherbench_small:
                auxiliary_var_size = torch.Size(
                    [n_channels, 16 // downsampling_factor, 16 // downsampling_factor])
            else:
                auxiliary_var_size = torch.Size(
                    [n_channels, data_size[1] // downsampling_factor, data_size[2] // downsampling_factor])
        elif unet_noise_method == "dropout":
            wrap_net = False  # do not wrap in the conditional generative model
    if wrap_net:
        net = load_net(nets_folder_list[i] + f"net{name_postfix_list[i]}.pth", ConditionalGenerativeModel, inner_net,
                       size_auxiliary_variable=auxiliary_var_size, base_measure=base_measure,
                       number_generations_per_forward_call=prediction_ensemble_size, seed=seed + 1)
        nets_list.append(net)
    else:
        net = load_net(nets_folder_list[i] + f"net{name_postfix_list[i]}.pth", DiscardWindowSizeDim, inner_net)
        nets_list.append(net)
    if cuda:
        net.cuda()
# --- predictions ---
# Run every trained net over the validation and test sets, collecting the
# ensemble forecasts as numpy arrays (one entry per method).
# can directly feed through the whole test set for now; if it does not work well then, I will batch it.
predictions_val_list = []
predictions_test_list = []
for current_net in nets_list:
    with torch.no_grad():
        if model_is_weatherbench:
            # shape (n_val, ensemble_size, lon, lat, n_fields)
            predictions_val, target_data_val = get_predictions_and_target(data_loader_val, current_net, cuda)
            predictions_test, target_data_test = get_predictions_and_target(data_loader_test, current_net, cuda)
            # _map is with the original shape. The following instead is flattened:
            predictions_val = predictions_val.flatten(2, -1)
            target_data_val = target_data_val.flatten(1, -1)
            predictions_test = predictions_test.flatten(2, -1)
            target_data_test = target_data_test.flatten(1, -1)
        else:
            predictions_val = current_net(input_data_val)  # shape (n_val, ensemble_size, data_size)
            predictions_test = current_net(input_data_test)  # shape (n_test, ensemble_size, data_size)
    predictions_val_list.append(predictions_val.detach().cpu().numpy())
    predictions_test_list.append(predictions_test.detach().cpu().numpy())
# -- plots --
# Build one comparison figure: the true test trajectory plus, for each method,
# the ensemble median and a central 99% band over the selected timestep window.
with torch.no_grad():
    if model_is_weatherbench:
        # we visualize only the first 8 variables (evenly spaced indices).
        # NOTE(review): this subsetting is applied to predictions_test /
        # target_data_test but NOT to predictions_test_list, which is what the
        # per-method plotting below actually uses — verify this is intended.
        variable_list = np.linspace(0, target_data_test.shape[-1] - 1, 8, dtype=int)
        predictions_test = predictions_test[:, :, variable_list]
        target_data_test = target_data_test[:, variable_list]
    target_data_test_for_plot = target_data_test.cpu()
    time_vec = torch.arange(len(predictions_test)).cpu()
# only the first component (index 0) of the state is plotted below
data_size = 1
if model == "lorenz":
    var_name = r"$y$"
elif model == "WeatherBench":
    # todo write here the correct lon and lat coordinates!
    var_name = r"$x_{}$".format(1)
else:
    var_name = r"$x_{}$".format(1)
# predictions: median and 99% quantile region
fig, ax = plt.subplots(nrows=data_size, ncols=1, sharex="col", figsize=(8, 3.5) if data_size == 1 else None)
label_size = 13
# add the target values:
ax.plot(time_vec[plot_start_timestep:plot_end_timestep],
        target_data_test_for_plot[plot_start_timestep:plot_end_timestep, 0], ls="--", color="black",
        label="True")
size = 99  # width (in percent) of the central credible band
for i in range(3):
    # per-timestep median and (50 +/- size/2) percentiles over the ensemble axis
    predictions_median = np.median(predictions_test_list[i], axis=1)
    predictions_lower = np.percentile(predictions_test_list[i], 50 - size / 2, axis=1)
    predictions_upper = np.percentile(predictions_test_list[i], 50 + size / 2, axis=1)
    ax.plot(time_vec[plot_start_timestep:plot_end_timestep],
            predictions_median[plot_start_timestep:plot_end_timestep, 0], ls="-", color=f"C{i}",
            label=methods_list[i], alpha=0.6)
    ax.fill_between(
        time_vec[plot_start_timestep:plot_end_timestep], alpha=0.2, color=f"C{i}",
        y1=predictions_lower[plot_start_timestep:plot_end_timestep, 0],
        y2=predictions_upper[plot_start_timestep:plot_end_timestep, 0])
ax.set_ylabel(var_name, size=label_size)
ax.tick_params(axis='both', which='major', labelsize=label_size)
ax.legend(fontsize=label_size)
ax.set_xlabel(r"$t$", size=label_size)
# fig.suptitle(f"Median and {size}% credible region, " + model_name_for_plot, size=title_size)
# plt.show()
if save_plots:
    # fall back to the default output folders when not given on the CLI
    if root_folder is None:
        root_folder = default_root_folder
    if model_folder is None:
        model_folder = root_folder + '/' + default_model_folder[model]
    # clip the saved figure to a fixed bounding box (in inches)
    bbox = Bbox(np.array([[0.3, -0.1], [7.3, 3.2]]))
    plt.savefig(model_folder + "prediction_median_comparison." + ("pdf" if save_pdf else "png"), dpi=400,
                bbox_inches=bbox)
plt.close()