Skip to content

Commit

Permalink
More U-Net Like UOLO (ultralytics#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
manole-alexandru committed Apr 21, 2023
1 parent 2aa857f commit 0d57b4d
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,36 +852,32 @@ class Seg(nn.Module):
def __init__(self, in_channels):
super().__init__()
print('SEG in channels: ', in_channels)
self.cv1 = Conv(in_channels, 8, k=3)
self.cv1 = Conv(in_channels, 96, k=3)
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.cv2 = Conv(8, 16, k=3)
self.cv3 = Conv(16, 8, k=3)
self.cv4 = Conv(8, 1, act=False)
self.cv2 = Conv(192, 48, k=3)
self.cv3 = Conv(48, 16, k=3)
self.cv4 = Conv(16, 1, act=False)
self.relu = nn.ReLU()

# self.dropout_weak = nn.Dropout(0.25)
self.dropout_normal = nn.Dropout(0.5)
# self.sigmoid = nn.Sigmoid()

def forward(self, x, new_x):
print('new x:', new_x.shape)
def forward(self, x, skipped_input):
# print('----entry shape', x.shape, '---\n')
x = self.cv1(x)
x = self.upsample(x)
print('post unsample 1:', x.shape)
# x = self.relu(x)
# Here you could use 2/1 96 - channels
x = torch.cat((x, skipped_input), 1)

# print('----upsample shape', x.shape, '---\n')
x = self.cv2(x)
x = self.upsample(x)
print('post unsample 2:', x.shape)

# x = self.relu(x)
# x = self.dropout_normal(x)
x = self.cv3(x)
x = self.upsample(x)
print('post unsample 3:', x.shape)
# print('----out shape', x.shape, '---\n')
# x = self.sigmoid(x)

x = self.cv4(x)
return x
Expand Down

0 comments on commit 0d57b4d

Please sign in to comment.