-
Notifications
You must be signed in to change notification settings - Fork 5
/
refine_subnet.py
61 lines (49 loc) · 2.18 KB
/
refine_subnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
from layer_utils import *
# dimensions of image [batch_size, channels, height, width]
class RefineSubnet(nn.Module):
    """Encoder / residual / decoder network that adds a learned residual to its input.

    The input is passed through three convolution stages, three residual
    blocks and three decoder stages; the result is added to a copy of the
    input and clamped per channel so that the de-normalized image stays in
    ``[0, 1]``.

    ``ConvLayer``, ``ResidualBlock`` and ``ResizeConvLayer`` come from
    ``layer_utils`` (not shown here).  The final addition in ``forward``
    requires the decoder output to have the same shape as the input, so the
    two ``ResizeConvLayer`` stages presumably undo the two stride-2
    convolutions -- confirm against ``layer_utils``.
    """

    # Per-channel normalization statistics used to derive the clamp bounds.
    # NOTE(review): the conventional ImageNet std for the first channel is
    # 0.229; 0.299 is kept here to preserve existing behavior -- confirm
    # against the normalization applied when the images are loaded.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.299, 0.224, 0.225)

    def __init__(self):
        super(RefineSubnet, self).__init__()

        # Bilinear upsampling layer. It is currently not applied in
        # ``forward`` (the call there is disabled) but is kept as an
        # attribute for callers that may reference it.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

        # Initial convolution layers
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = nn.InstanceNorm2d(128, affine=True)

        # Residual layers
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)

        # Upsampling layers
        self.rezconv1 = ResizeConvLayer(128, 64, kernel_size=3, stride=1)
        self.in4 = nn.InstanceNorm2d(64, affine=True)
        self.rezconv2 = ResizeConvLayer(64, 32, kernel_size=3, stride=1)
        self.in5 = nn.InstanceNorm2d(32, affine=True)
        self.rezconv3 = ConvLayer(32, 3, kernel_size=3, stride=1)

        # Non-linearities
        self.relu = nn.ReLU()

    def forward(self, X):
        """Run the network and return the refined image plus the content target.

        Args:
            X: input image tensor, indexed as [batch, channel, height, width]
                with three channels (the first layer takes 3 input channels
                and the clamp below addresses channels 0-2).

        Returns:
            A tuple ``(y, resized_input_img)`` where ``y`` is the clamped sum
            of the network output and the input, and ``resized_input_img`` is
            a clone of the input (used as the content target).
        """
        in_X = X
        # Test-time upsampling of the input is currently disabled:
        # if not self.training: in_X = self.upsample(in_X)

        # The (would-be resized) input image is the content target.
        resized_input_img = in_X.clone()

        y = self.relu(self.in1(self.conv1(in_X)))
        y = self.relu(self.in2(self.conv2(y)))
        y = self.relu(self.in3(self.conv3(y)))
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.relu(self.in4(self.rezconv1(y)))
        y = self.relu(self.in5(self.rezconv2(y)))
        y = self.rezconv3(y)

        # The network predicts a residual on top of the input image.
        y = y + resized_input_img

        # Clamp each channel so the image lies in [0, 1] after
        # de-normalization.  The clamp is applied to every sample in the
        # batch (y[:, c]), not only to the first one.
        for channel, (mean, std) in enumerate(zip(self._NORM_MEAN, self._NORM_STD)):
            y[:, channel].clamp_((0 - mean) / std, (1 - mean) / std)

        return y, resized_input_img