-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
66 lines (54 loc) · 2.05 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torchvision
from torchvision import transforms
_dataset_name = ["mnist", "cifar10", "gtsrb", "imagenet"]
_mean = {
"mnist": [0.5, 0.5, 0.5],
"cifar10": [0.4914, 0.4822, 0.4465],
"gtsrb": [0.3337, 0.3064, 0.3171],
"imagenet": [0.485, 0.456, 0.406],
}
_std = {
"mnist": [0.5, 0.5, 0.5],
"cifar10": [0.2470, 0.2435, 0.2616],
"gtsrb": [0.2672, 0.2564, 0.2629],
"imagenet": [0.229, 0.224, 0.225],
}
_size = {
"mnist": (28, 28),
"cifar10": (32, 32),
"gtsrb": (32, 32),
"imagenet": (224, 224),
}
def get_totensor_topil():
    """Return a fresh ``(ToTensor, ToPILImage)`` transform pair.

    The first converts a PIL image to a tensor; the second converts a
    tensor back to a PIL image.
    """
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()
    return to_tensor, to_pil
def get_normalize_unnormalize(dataset):
    """Build the ``Normalize`` transform for *dataset* and its exact inverse.

    Args:
        dataset: a dataset name listed in ``_dataset_name``.

    Returns:
        ``(normalize, unnormalize)``:
        - ``normalize`` maps ``x`` to ``(x - mean) / std`` channel-wise.
        - ``unnormalize`` maps it back to ``x * std + mean``.

    Raises:
        ValueError: if *dataset* is not a supported dataset name.
    """
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would leave a bare KeyError below.
    if dataset not in _dataset_name:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of {_dataset_name}"
        )
    mean = torch.tensor(_mean[dataset], dtype=torch.float32)
    std = torch.tensor(_std[dataset], dtype=torch.float32)
    normalize = transforms.Normalize(mean, std)
    # Normalize computes (x - m) / s. With m = -mean/std and s = 1/std this is
    # (x + mean/std) * std = x * std + mean, i.e. the inverse of `normalize`.
    unnormalize = transforms.Normalize(-mean / std, 1 / std)
    return normalize, unnormalize
def get_bounds(x, dataset):
    """Return the normalized images of an all-zeros and an all-ones tensor.

    Both tensors are built with ``zeros_like`` / ``ones_like`` from *x*, so they
    share its shape, dtype and device. Because every std in ``_std`` is
    positive, they are the element-wise lower and upper bounds of normalized
    inputs whose raw pixel values lie in [0, 1].

    Args:
        x: reference tensor whose shape/dtype/device the bounds copy.
        dataset: a supported dataset name.

    Returns:
        ``(lowerbound, upperbound)``.
    """
    transform = get_normalize_unnormalize(dataset)[0]
    lowerbound = transform(torch.zeros_like(x))
    upperbound = transform(torch.ones_like(x))
    return lowerbound, upperbound
def get_normalized_clip(dataset):
    """Return a callable that clips a normalized tensor to the valid range.

    The callable takes a tensor ``x``. It clamps ``x`` element-wise between
    the normalized images of 0 and 1 for *dataset*.
    """
    normalize = get_normalize_unnormalize(dataset)[0]

    def clip(x):
        # Bounds are rebuilt from x on every call so they match its shape,
        # dtype and device.
        lower = normalize(torch.zeros_like(x))
        upper = normalize(torch.ones_like(x))
        return torch.min(torch.max(x, lower), upper)

    return clip
def get_resize(size):
    """Return a ``transforms.Resize`` for *size*.

    Args:
        size: either a dataset name, which is looked up in ``_size``, or any
            other value; non-string values are passed to
            ``transforms.Resize`` unchanged.

    Raises:
        ValueError: if *size* is a string that is not a supported dataset name.
    """
    if isinstance(size, str):
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O`` and reports a helpful message.
        if size not in _dataset_name:
            raise ValueError(
                f"unsupported dataset {size!r}; expected one of {_dataset_name}"
            )
        size = _size[size]
    return transforms.Resize(size)
def get_preprocess_deprocess(dataset, size=None):
    """Build the ``(preprocess, deprocess)`` transform pipelines for *dataset*.

    - preprocess: optional ``Resize`` -> ``ToTensor`` -> normalize.
    - deprocess: unnormalize -> ``ToPILImage``. It does not undo any resize.

    Args:
        dataset: a supported dataset name.
        size: if not None, passed to ``get_resize``; this may be a dataset name
            or a size value.

    NOTE(review): deprocess does not clamp. Out-of-range tensor values are
    presumably not mapped cleanly by ``ToPILImage``, so clip beforehand if
    that matters.
    """
    to_tensor, to_pil = get_totensor_topil()
    normalize, unnormalize = get_normalize_unnormalize(dataset)
    steps = [to_tensor, normalize]
    if size is not None:
        # Resizing operates on the PIL image, so it must run before ToTensor.
        steps.insert(0, get_resize(size))
    return transforms.Compose(steps), transforms.Compose([unnormalize, to_pil])