From 0d57b4d9197624f9dcfd979faafaeafdf4ec7856 Mon Sep 17 00:00:00 2001 From: ManoleAlexandru99 Date: Fri, 21 Apr 2023 19:32:15 +0300 Subject: [PATCH] More U-Net Like UOLO (#0019) --- models/common.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/models/common.py b/models/common.py index ea3190fdadd5..83d37a8ee35a 100644 --- a/models/common.py +++ b/models/common.py @@ -852,36 +852,32 @@ class Seg(nn.Module): def __init__(self, in_channels): super().__init__() print('SEG in channels: ', in_channels) - self.cv1 = Conv(in_channels, 8, k=3) + self.cv1 = Conv(in_channels, 96, k=3) self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.cv2 = Conv(8, 16, k=3) - self.cv3 = Conv(16, 8, k=3) - self.cv4 = Conv(8, 1, act=False) + self.cv2 = Conv(192, 48, k=3) + self.cv3 = Conv(48, 16, k=3) + self.cv4 = Conv(16, 1, act=False) self.relu = nn.ReLU() # self.dropout_weak = nn.Dropout(0.25) self.dropout_normal = nn.Dropout(0.5) # self.sigmoid = nn.Sigmoid() - def forward(self, x, new_x): - print('new x:', new_x.shape) + def forward(self, x, skipped_input): # print('----entry shape', x.shape, '---\n') x = self.cv1(x) x = self.upsample(x) - print('post unsample 1:', x.shape) - # x = self.relu(x) + # Here you could use 2/1 96 - channels + x = torch.cat((x, skipped_input), 1) + # print('----upsample shape', x.shape, '---\n') x = self.cv2(x) x = self.upsample(x) - print('post unsample 2:', x.shape) - # x = self.relu(x) # x = self.dropout_normal(x) x = self.cv3(x) x = self.upsample(x) - print('post unsample 3:', x.shape) # print('----out shape', x.shape, '---\n') - # x = self.sigmoid(x) x = self.cv4(x) return x