diff --git a/models/common.py b/models/common.py index fc071eba5cd3..aba5741ff28a 100644 --- a/models/common.py +++ b/models/common.py @@ -854,10 +854,10 @@ def __init__(self, in_channels): self.cv1 = Conv(in_channels, 32, k=3) self.cv11 = Conv(96, 32, k=3) self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - self.cv2 = Conv(64, 8, k=3) - self.cv22 = Conv(48, 8, k=3) - self.cv3 = Conv(16, 8, k=3) - self.cv4 = Conv(8, 1, act=False) + self.cv2 = Conv(64, 16, k=3) + self.cv22 = Conv(48, 16, k=3) + self.cv3 = Conv(32, 4, k=3) + self.cv4 = Conv(4, 1, act=False) self.relu = nn.ReLU() self.dropout_normal = nn.Dropout(0.5)