forked from simo-an/FastFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
154 lines (124 loc) · 6 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import torch
from torch import nn
from torchsummary import summary
import config as c
from freia_funcs import permute_layer, glow_coupling_layer, F_fully_connected, ReversibleGraphNet, OutputNode, \
InputNode, Node
import FrEIA.modules as Fm
import FrEIA.framework as Ff
import torchvision.models as models
import numpy as np
WEIGHT_DIR = './weights'
MODEL_DIR = './models'
def subnet_conv_1(c_in, c_out):
    """Build the 1x1-convolution subnet used inside a GLOW coupling block.

    Maps c_in channels to c_out channels through a hidden layer of
    ``config.subnet_conv_dim`` channels with a ReLU in between.
    """
    hidden = c.subnet_conv_dim
    layers = [
        nn.Conv2d(c_in, hidden, kernel_size=(1, 1), padding='same'),
        nn.ReLU(),
        nn.Conv2d(hidden, c_out, kernel_size=(1, 1), padding='same'),
    ]
    return nn.Sequential(*layers)
def subnet_conv_3(c_in, c_out):
    """Build the 3x3-convolution subnet used inside a GLOW coupling block.

    Maps c_in channels to c_out channels through a hidden layer of
    ``config.subnet_conv_dim`` channels with a ReLU in between; 'same'
    padding keeps the spatial size unchanged.
    """
    hidden = c.subnet_conv_dim
    layers = [
        nn.Conv2d(c_in, hidden, kernel_size=(3, 3), padding='same'),
        nn.ReLU(),
        nn.Conv2d(hidden, c_out, kernel_size=(3, 3), padding='same'),
    ]
    return nn.Sequential(*layers)
# NOTE(review): removed dead code — a module-level triple-quoted string that
# held the legacy `nf_head` builder (permute_layer / glow_coupling_layer /
# F_fully_connected from freia_funcs, i.e. the old FrEIA-style API). It was
# never executed; recover it from version control if the fully-connected
# flow head is ever needed again.
def nf_fast_flow(input_dim, clamp=1.2):
    """Build the FastFlow normalizing-flow head as a FrEIA ``GraphINN``.

    The flow alternates GLOW coupling blocks whose subnets use 3x3 (even
    blocks) and 1x1 (odd blocks) convolutions; each coupling block is
    preceded by a fixed, seeded random channel permutation.

    Args:
        input_dim: 3-tuple with the (unbatched) feature-map dimensions,
            forwarded positionally to ``Ff.InputNode`` — must match the
            output of the feature extractor.
        clamp: clamping parameter of the GLOW coupling blocks. Defaults to
            1.2, the value hard-coded in the original implementation, so
            existing callers are unaffected.

    Returns:
        Ff.GraphINN: the invertible flow network.
    """
    nodes = [Ff.InputNode(input_dim[0], input_dim[1], input_dim[2], name='input')]
    for k in range(c.n_coupling_blocks):
        # Fixed (seeded, hence reproducible) channel permutation before
        # every coupling block.
        nodes.append(Ff.Node(nodes[-1],
                             Fm.PermuteRandom,
                             {'seed': k},
                             name=F'permute_high_res_{k}'))
        # Alternate 3x3 and 1x1 subnets, as in the FastFlow paper; both
        # branches of the original duplicated everything but the subnet.
        subnet = subnet_conv_3 if k % 2 == 0 else subnet_conv_1
        nodes.append(Ff.Node(nodes[-1],
                             Fm.GLOWCouplingBlock,
                             {'subnet_constructor': subnet, 'clamp': clamp},
                             name=F'conv_high_res_{k}'))
    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    # Leftover debug prints of `nodes` and the built coder removed.
    return Ff.GraphINN(nodes)
class FastFlow(nn.Module):
    """FastFlow anomaly-detection model.

    A frozen, ImageNet-pretrained feature extractor (ResNet-18 prefix or
    DeiT, selected via ``config.extractor_name``) feeds a 2D normalizing
    flow built by ``nf_fast_flow``. Only the flow head has trainable
    parameters; ``forward`` returns the latent ``z`` and the
    log-determinant of the Jacobian for the flow's NLL loss.
    """

    def __init__(self):
        super(FastFlow, self).__init__()
        if c.extractor_name == "resnet18":
            #self.feature_extractor = torch.hub.load('facebookresearch/deit:main', 'deit_base_distilled_patch16_224', pretrained=True)
            #self.feature_extractor = torch.nn.Sequential(*(list(self.feature_extractor.children())[:-2])) # I remove the last two layers
            self.feature_extractor = models.resnet18(pretrained=True)
            # Keep only the first 5 children of ResNet-18 (stem + first
            # residual stage); per the author, the output is 64x64x64 for a
            # 3x256x256 input.
            self.feature_extractor = torch.nn.Sequential(*(list(self.feature_extractor.children())[:5]))
            # Freeze the extractor so only the flow head is trained.
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
            # Debug: print the extractor architecture/output shapes.
            print(summary(self.feature_extractor, (3,256,256)))
            #self.feature_extractor = torch.load('./pretrained/M48_448.pth') # wrong: loads only the weights, not the model
            #self.feature_extractor.eval() # to deactivate the dropout layers
            # This input is unfortunately hardcoded: it must match the output
            # dimensions of the feature extractor (batch size excluded).
            self.nf = nf_fast_flow((64,64,64))
        elif c.extractor_name == "deit":
            self.feature_extractor = torch.hub.load('facebookresearch/deit:main', 'deit_base_distilled_patch16_384', pretrained=True)
            #print(help(self.feature_extractor ))
            # Drop the last two children of the DeiT model (classification
            # head), keeping the token features.
            self.feature_extractor = torch.nn.Sequential(*(list(self.feature_extractor.children())[:-2]))
            # Freeze the extractor so only the flow head is trained.
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
            # Debug: print the extractor architecture/output shapes.
            print(summary(self.feature_extractor, (3,384,384)))
            # Hardcoded to match the reshape done in forward() for DeiT.
            self.nf = nf_fast_flow((24,24,768))

    def forward(self, x):
        """Run the extractor and the flow.

        Args:
            x: input image batch; presumably (B, 3, 256, 256) for resnet18
               and (B, 3, 384, 384) for deit, matching the summary() calls
               above — TODO confirm against the data loader.

        Returns:
            Tuple (z, log_jac_det) as produced by the FrEIA flow.
        """
        y_cat = list()
        '''
        for s in range(c.n_scales):
            x_scaled = F.interpolate(x, size=c.img_size[0] // (2 ** s)) if s > 0 else x
            #feat_s = self.feature_extractor.features(x_scaled)
            feat_s = self.feature_extractor(x_scaled)
            y_cat.append(torch.mean(feat_s, dim=(2, 3)))
        '''
        feat_s = self.feature_extractor(x)
        #y_cat.append(feat_s)
        #y = torch.cat(y_cat, dim=3)
        #print(feat_s.size())
        # Reshape the linearized DeiT token output back to a 2D map:
        # from (B, 576, 768) to (B, 24, 24, 768) per the author's comment.
        # NOTE(review): the flow's InputNode was built as (24, 24, 768), so
        # the first 24 axis acts as channels here; channels-first would be
        # (768, 24, 24) — confirm this ordering is intended.
        if c.extractor_name == "deit":
            dim_batch = feat_s.size(dim=0)
            feat_s = feat_s.reshape(dim_batch,24,24,768)
        #print(feat_s.size())
        z, log_jac_det = self.nf(feat_s)
        return z, log_jac_det
def save_model(model, filename):
    """Serialize the whole *model* object to ``MODEL_DIR/filename``.

    Args:
        model: the model object to pickle (full object, not just weights).
        filename: target file name inside ``MODEL_DIR``.
    """
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...)` guard.
    os.makedirs(MODEL_DIR, exist_ok=True)
    torch.save(model, os.path.join(MODEL_DIR, filename))
def load_model(filename, map_location=None):
    """Load a full model object previously saved with ``save_model``.

    Args:
        filename: file name inside ``MODEL_DIR``.
        map_location: optional device remapping forwarded to ``torch.load``
            (e.g. ``'cpu'`` to load a GPU-saved model on a CPU-only
            machine). Defaults to None, preserving the original behavior.

    Returns:
        The deserialized model object.
    """
    path = os.path.join(MODEL_DIR, filename)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoint files from trusted sources.
    model = torch.load(path, map_location=map_location)
    return model
def save_weights(model, filename):
    """Save only the model's ``state_dict`` to ``WEIGHT_DIR/filename``.

    Args:
        model: model whose weights are saved.
        filename: target file name inside ``WEIGHT_DIR``.
    """
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...)` guard.
    os.makedirs(WEIGHT_DIR, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(WEIGHT_DIR, filename))
def load_weights(model, filename, map_location=None):
    """Load a ``state_dict`` from ``WEIGHT_DIR/filename`` into *model*.

    Args:
        model: model instance to populate (mutated in place).
        filename: file name inside ``WEIGHT_DIR``.
        map_location: optional device remapping forwarded to ``torch.load``
            (e.g. ``'cpu'``). Defaults to None, preserving the original
            behavior.

    Returns:
        The same *model* instance, for call-chaining convenience.
    """
    path = os.path.join(WEIGHT_DIR, filename)
    model.load_state_dict(torch.load(path, map_location=map_location))
    return model