From b00415342937391553a9d89252488a23f9a1f31d Mon Sep 17 00:00:00 2001 From: ManoleAlexandru99 Date: Sat, 22 Apr 2023 10:43:04 +0300 Subject: [PATCH] Chose best #0019 version --- models/common.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/models/common.py b/models/common.py index 85e047906b50..ff3aacff5da7 100644 --- a/models/common.py +++ b/models/common.py @@ -850,20 +850,24 @@ def forward(self, x): class Seg(nn.Module): def __init__(self, in_channels): + super().__init__() - self.cv1 = Conv(in_channels, 96, k=3) + self.cv1 = Conv(in_channels, 32, k=3) + self.cv11 = Conv(96, 32, k=3) + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.cv2 = Conv(192, 48, k=3) - self.cv3 = Conv(48, 16, k=3) + self.cv2 = Conv(64, 32, k=3) + self.cv3 = Conv(32, 16, k=3) self.cv4 = Conv(16, 1, act=False) self.relu = nn.ReLU() - self.dropout_normal = nn.Dropout(0.5) def forward(self, x, skipped_input): + x = self.cv1(x) x = self.upsample(x) - x = torch.cat((x, skipped_input[0]), 1) # Skip connection + x2 = self.cv11(skipped_input[0]) + x = torch.cat((x, x2), 1) # Skip connection x = self.cv2(x) x = self.upsample(x)