forked from yghlc/Unet_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
94 lines (75 loc) · 3.08 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
import torch.nn.functional as functional
def add_conv_stage(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True, useBN=False):
    """Build a two-convolution stage: dim_in -> dim_out -> dim_out.

    With ``useBN`` each conv is followed by BatchNorm2d and LeakyReLU(0.1);
    otherwise each conv is followed by a plain ReLU. Default kernel/stride/
    padding (3/1/1) preserve the spatial size.
    """
    def _conv(cin, cout):
        # Shared conv constructor so both halves of the stage stay consistent.
        return nn.Conv2d(cin, cout, kernel_size=kernel_size, stride=stride,
                         padding=padding, bias=bias)

    layers = []
    for cin, cout in ((dim_in, dim_out), (dim_out, dim_out)):
        layers.append(_conv(cin, cout))
        if useBN:
            layers.append(nn.BatchNorm2d(cout))
            layers.append(nn.LeakyReLU(0.1))
        else:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)
def add_merge_stage(ch_coarse, ch_fine, in_coarse=None, in_fine=None, upsample=None):
    """Build the upsampling half of a merge stage: a 2x transposed conv
    mapping ``ch_coarse`` channels to ``ch_fine`` channels.

    Bug fix: the original called ``torch.cat(conv, in_fine)`` on an
    ``nn.Module`` at construction time (a TypeError — ``torch.cat`` takes a
    sequence of tensors) and had an unreachable ``upsample(in_coarse)`` after
    the ``return``. Both broken statements are removed; the concatenation of
    the upsampled coarse features with the fine skip connection belongs in
    the caller's ``forward`` (as ``Net.forward`` already does).

    ``in_coarse``/``in_fine``/``upsample`` are kept (now optional) for
    backward compatibility with existing call sites; they are unused here.
    """
    return nn.Sequential(
        # kernel=4, stride=2, padding=1 doubles the spatial resolution.
        nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False)
    )
def upsample(ch_coarse, ch_fine):
    """Return a 2x spatial upsampler: transposed conv (ch_coarse -> ch_fine)
    followed by ReLU. kernel=4, stride=2, padding=1 exactly doubles H and W."""
    deconv = nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False)
    activation = nn.ReLU()
    return nn.Sequential(deconv, activation)
class Net(nn.Module):
    """U-Net-style encoder/decoder.

    Encoder: five conv stages (3 -> 32 -> 64 -> 128 -> 256 -> 512 channels)
    with 2x max-pooling between them. Decoder: four upsample-and-concatenate
    steps, each followed by a conv stage, then a final 1-channel conv with
    Sigmoid. Input is expected to be a 3-channel image batch; spatial dims
    should be divisible by 16 so the skip concatenations line up.
    """

    def __init__(self, useBN=False):
        super().__init__()
        # Contracting path.
        self.conv1 = add_conv_stage(3, 32, useBN=useBN)
        self.conv2 = add_conv_stage(32, 64, useBN=useBN)
        self.conv3 = add_conv_stage(64, 128, useBN=useBN)
        self.conv4 = add_conv_stage(128, 256, useBN=useBN)
        self.conv5 = add_conv_stage(256, 512, useBN=useBN)
        # Expanding path (inputs are upsampled features concatenated with
        # the matching encoder output, hence the doubled input channels).
        self.conv4m = add_conv_stage(512, 256, useBN=useBN)
        self.conv3m = add_conv_stage(256, 128, useBN=useBN)
        self.conv2m = add_conv_stage(128, 64, useBN=useBN)
        self.conv1m = add_conv_stage(64, 32, useBN=useBN)
        # Output head: single-channel probability map.
        self.conv0 = nn.Sequential(
            nn.Conv2d(32, 1, 3, 1, 1),
            nn.Sigmoid()
        )
        self.max_pool = nn.MaxPool2d(2)
        # One learned 2x upsampler per decoder level.
        self.upsample54 = upsample(512, 256)
        self.upsample43 = upsample(256, 128)
        self.upsample32 = upsample(128, 64)
        self.upsample21 = upsample(64, 32)
        # Weight initialization: zero every conv/deconv bias (weights keep
        # their default PyTorch initialization).
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        enc1 = self.conv1(x)
        enc2 = self.conv2(self.max_pool(enc1))
        enc3 = self.conv3(self.max_pool(enc2))
        enc4 = self.conv4(self.max_pool(enc3))
        enc5 = self.conv5(self.max_pool(enc4))
        # Decoder: upsample, concatenate with the encoder skip along the
        # channel axis (dim 1), then refine with a conv stage.
        dec4 = self.conv4m(torch.cat((self.upsample54(enc5), enc4), 1))
        dec3 = self.conv3m(torch.cat((self.upsample43(dec4), enc3), 1))
        dec2 = self.conv2m(torch.cat((self.upsample32(dec3), enc2), 1))
        dec1 = self.conv1m(torch.cat((self.upsample21(dec2), enc1), 1))
        return self.conv0(dec1)