From 5919f8414195ed2cf88a0b7cb8d0c14af24e19c0 Mon Sep 17 00:00:00 2001 From: ManoleAlexandru99 Date: Mon, 1 May 2023 10:15:19 +0300 Subject: [PATCH] Best #0019 version --- models/common.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/models/common.py b/models/common.py index 83d37a8ee35a..7f76edc64287 100644 --- a/models/common.py +++ b/models/common.py @@ -852,23 +852,23 @@ class Seg(nn.Module): def __init__(self, in_channels): super().__init__() print('SEG in channels: ', in_channels) - self.cv1 = Conv(in_channels, 96, k=3) + self.cv1 = Conv(in_channels, 32, k=3) + self.cv11 = Conv(96, 32, k=3) self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.cv2 = Conv(192, 48, k=3) - self.cv3 = Conv(48, 16, k=3) + self.cv2 = Conv(64, 32, k=3) + self.cv3 = Conv(32, 16, k=3) self.cv4 = Conv(16, 1, act=False) - self.relu = nn.ReLU() - # self.dropout_weak = nn.Dropout(0.25) self.dropout_normal = nn.Dropout(0.5) - # self.sigmoid = nn.Sigmoid() def forward(self, x, skipped_input): # print('----entry shape', x.shape, '---\n') x = self.cv1(x) x = self.upsample(x) + + x2 = self.cv11(skipped_input[0]) # Here you could use 2/1 96 - channels - x = torch.cat((x, skipped_input), 1) + x = torch.cat((x, x2), 1) # print('----upsample shape', x.shape, '---\n') x = self.cv2(x)