diff --git a/models/common.py b/models/common.py index 5ccf7d2b5545..3a7724819104 100644 --- a/models/common.py +++ b/models/common.py @@ -852,11 +852,11 @@ class Seg(nn.Module): def __init__(self, in_channels): super().__init__() print('SEG in channels: ', in_channels) - self.cv1 = Conv(in_channels, 64, k=3) + self.cv1 = Conv(in_channels, 128, k=3) self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.cv2 = Conv(64, 64, k=3) - self.cv3 = Conv(64, 32, k=3) - self.cv4 = Conv(32, 1, act=False) + self.cv2 = Conv(128, 64, k=3) + self.cv3 = Conv(64, 64, k=3) + self.cv4 = Conv(64, 1, act=False) self.relu = nn.ReLU() # self.dropout_weak = nn.Dropout(0.25)