# SSPSR.py

import math

import torch
import torch.nn as nn

from common import *
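
# `common` is assumed to provide EDSR-style building blocks: default_conv,
# ResBlock, ResAttentionBlock (a residual block with channel attention), and
# Upsampler (pixel-shuffle upsampling). As a hedge in case common.py is not
# at hand, a minimal sketch of the assumed default_conv:
#
#   def default_conv(in_channels, out_channels, kernel_size, bias=True):
#       return nn.Conv2d(in_channels, out_channels, kernel_size,
#                        padding=kernel_size // 2, bias=bias)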


class SSB(nn.Module):
    """Spatial-spectral block: a spatial residual block (kernel_size convs)
    followed by a spectral residual attention block (1x1 convs)."""

    def __init__(self, n_feats, kernel_size, act, res_scale, conv=default_conv):
        super(SSB, self).__init__()
        self.spa = ResBlock(conv, n_feats, kernel_size, act=act, res_scale=res_scale)
        self.spc = ResAttentionBlock(conv, n_feats, 1, act=act, res_scale=res_scale)

    def forward(self, x):
        return self.spc(self.spa(x))
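
# Shape sanity check (a sketch; assumes common.py is importable): an SSB
# preserves both the channel count and the spatial size.
#
#   blk = SSB(n_feats=64, kernel_size=3, act=nn.ReLU(True), res_scale=0.1)
#   blk(torch.randn(1, 64, 32, 32)).shape   # torch.Size([1, 64, 32, 32])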


class SSPN(nn.Module):
    """Spatial-spectral prior network: a stack of n_blocks SSBs wrapped in a
    long skip connection."""

    def __init__(self, n_feats, n_blocks, act, res_scale):
        super(SSPN, self).__init__()
        kernel_size = 3

        m = []
        for i in range(n_blocks):
            m.append(SSB(n_feats, kernel_size, act=act, res_scale=res_scale))
        self.net = nn.Sequential(*m)

    def forward(self, x):
        res = self.net(x)
        res += x
        return res
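
# Because SSPN is purely residual, its output shape equals its input shape,
# e.g. (1, 64, 32, 32) in -> (1, 64, 32, 32) out, regardless of n_blocks.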


# a single branch of the proposed SSPSR
class BranchUnit(nn.Module):
    def __init__(self, n_colors, n_feats, n_blocks, act, res_scale, up_scale, use_tail=True, conv=default_conv):
        super(BranchUnit, self).__init__()
        kernel_size = 3
        self.head = conv(n_colors, n_feats, kernel_size)
        self.body = SSPN(n_feats, n_blocks, act, res_scale)
        self.upsample = Upsampler(conv, up_scale, n_feats)
        self.tail = None
        if use_tail:
            self.tail = conv(n_feats, n_colors, kernel_size)

    def forward(self, x):
        y = self.head(x)
        y = self.body(y)
        y = self.upsample(y)
        if self.tail is not None:
            y = self.tail(y)
        return y
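
# A branch maps n_colors bands to n_feats features, refines them with an
# SSPN, upsamples by up_scale, and (if use_tail) projects back to n_colors.
# Shape sketch with hypothetical numbers:
#
#   br = BranchUnit(n_colors=8, n_feats=64, n_blocks=3, act=nn.ReLU(True),
#                   res_scale=0.1, up_scale=2)
#   br(torch.randn(1, 8, 16, 16)).shape   # torch.Size([1, 8, 32, 32])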


class SSPSR(nn.Module):
    def __init__(self, n_subs, n_ovls, n_colors, n_blocks, n_feats, n_scale, res_scale, use_share=True, conv=default_conv):
        super(SSPSR, self).__init__()
        kernel_size = 3
        self.shared = use_share
        act = nn.ReLU(True)

        # calculate the group number (the number of branch networks)
        self.G = math.ceil((n_colors - n_ovls) / (n_subs - n_ovls))

        # calculate the group indices; each group covers n_subs bands and
        # overlaps its neighbors by n_ovls bands
        self.start_idx = []
        self.end_idx = []
        for g in range(self.G):
            sta_ind = (n_subs - n_ovls) * g
            end_ind = sta_ind + n_subs
            if end_ind > n_colors:
                end_ind = n_colors
                sta_ind = n_colors - n_subs
            self.start_idx.append(sta_ind)
            self.end_idx.append(end_ind)

        if self.shared:
            # up_scale=n_scale//2 means the LR input is upsampled by a
            # factor of n_scale//2 inside the branch network, and the
            # remaining 2x upsampling is done by the global trunk network
            self.branch = BranchUnit(n_subs, n_feats, n_blocks, act, res_scale, up_scale=n_scale // 2, conv=conv)
        else:
            self.branch = nn.ModuleList()
            for i in range(self.G):
                self.branch.append(BranchUnit(n_subs, n_feats, n_blocks, act, res_scale, up_scale=2, conv=conv))

        self.trunk = BranchUnit(n_colors, n_feats, n_blocks, act, res_scale, up_scale=2, use_tail=False, conv=conv)
        self.skip_conv = conv(n_colors, n_feats, kernel_size)
        self.final = conv(n_feats, n_colors, kernel_size)
        self.sca = n_scale // 2
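
    # Worked example of the grouping (hypothetical sizes): with n_colors=31,
    # n_subs=8, n_ovls=2 we get G = ceil((31 - 2) / (8 - 2)) = 5 and the
    # (start, end) pairs (0, 8), (6, 14), (12, 20), (18, 26), (23, 31); the
    # last group is shifted back so it still spans exactly n_subs bands.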
    def forward(self, x, lms):
        b, c, h, w = x.shape

        # intermediate result, upsampled by a factor of n_scale//2; allocate
        # it on the input's device instead of hard-coding .cuda() so the
        # model also runs on CPU
        y = torch.zeros(b, c, self.sca * h, self.sca * w, device=x.device)
        channel_counter = torch.zeros(c, device=x.device)

        for g in range(self.G):
            sta_ind = self.start_idx[g]
            end_ind = self.end_idx[g]

            xi = x[:, sta_ind:end_ind, :, :]
            if self.shared:
                xi = self.branch(xi)
            else:
                xi = self.branch[g](xi)

            y[:, sta_ind:end_ind, :, :] += xi
            channel_counter[sta_ind:end_ind] = channel_counter[sta_ind:end_ind] + 1

        # average the accumulated result by how many groups covered each
        # spectral index (overlapped bands are predicted more than once)
        y = y / channel_counter.unsqueeze(1).unsqueeze(2)

        y = self.trunk(y)
        y = y + self.skip_conv(lms)
        y = self.final(y)
        return y
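

if __name__ == "__main__":
    # Usage sketch with hypothetical hyperparameters: x is the LR
    # hyperspectral input and lms is its bicubic-upsampled counterpart at
    # the full n_scale resolution (here 4x), feeding the global skip path.
    model = SSPSR(n_subs=8, n_ovls=2, n_colors=31, n_blocks=3, n_feats=64,
                  n_scale=4, res_scale=0.1)
    x = torch.randn(1, 31, 16, 16)
    lms = torch.randn(1, 31, 64, 64)
    print(model(x, lms).shape)  # expected: torch.Size([1, 31, 64, 64])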