Skip to content

Commit

Permalink
Extended skip connection further
Browse files Browse the repository at this point in the history
  • Loading branch information
manole-alexandru committed Apr 21, 2023
1 parent 0d57b4d commit 84841ad
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 14 deletions.
14 changes: 3 additions & 11 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,34 +851,26 @@ class Seg(nn.Module):

def __init__(self, in_channels):
super().__init__()
print('SEG in channels: ', in_channels)
self.cv1 = Conv(in_channels, 96, k=3)
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.cv2 = Conv(192, 48, k=3)
self.cv3 = Conv(48, 16, k=3)
self.cv3 = Conv(96, 16, k=3)
self.cv4 = Conv(16, 1, act=False)
self.relu = nn.ReLU()

# self.dropout_weak = nn.Dropout(0.25)
self.dropout_normal = nn.Dropout(0.5)
# self.sigmoid = nn.Sigmoid()

def forward(self, x, skipped_input):
# print('----entry shape', x.shape, '---\n')
x = self.cv1(x)
x = self.upsample(x)
# Here you could use 2/1 96 - channels
x = torch.cat((x, skipped_input), 1)
x = torch.cat((x, skipped_input[0]), 1) # Skip connection

# print('----upsample shape', x.shape, '---\n')
x = self.cv2(x)
x = self.upsample(x)
x = torch.cat((x, skipped_input[1]), 1) # Skip connection

# x = self.dropout_normal(x)
x = self.cv3(x)
x = self.upsample(x)
# print('----out shape', x.shape, '---\n')

x = self.cv4(x)
return x

Expand Down
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True):

def forward(self, x):
old_x = x[:3]
new_skip_connect_info = x[3]
new_skip_connect_info = x[3:]
x = old_x
p = self.semantic_seg(x[0], new_skip_connect_info)
x = self.detect(self, x)
Expand Down
2 changes: 1 addition & 1 deletion models/yolov5m.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ head:
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)

[[17, 20, 23, 2], 1, SemanticSegment, [nc, anchors]], # Detect(P3, P4, P5)
[[17, 20, 23, 2, 0], 1, SemanticSegment, [nc, anchors]], # Detect(P3, P4, P5)
]
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def parse_opt(known=False):


def main(opt, callbacks=Callbacks()):
print('\n---------- VERSION:', '#0017', '----------\n')
print('\n---------- VERSION:', '#0019', '----------\n')
# Checks
if RANK in {-1, 0}:
print_args(vars(opt))
Expand Down

0 comments on commit 84841ad

Please sign in to comment.