diff --git a/models/common.py b/models/common.py index 3a7724819104..a037ef82e7fa 100644 --- a/models/common.py +++ b/models/common.py @@ -852,9 +852,9 @@ class Seg(nn.Module): def __init__(self, in_channels): super().__init__() print('SEG in channels: ', in_channels) - self.cv1 = Conv(in_channels, 128, k=3) + self.cv1 = Conv(in_channels, 32, k=3) self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.cv2 = Conv(128, 64, k=3) + self.cv2 = Conv(32, 64, k=3) self.cv3 = Conv(64, 64, k=3) self.cv4 = Conv(64, 1, act=False) self.relu = nn.ReLU()