From 2aa857f6bc55dbc0c762d3250fb134c06172376f Mon Sep 17 00:00:00 2001 From: ManoleAlexandru99 Date: Fri, 21 Apr 2023 18:56:35 +0300 Subject: [PATCH] Fixes to new conection --- models/common.py | 6 +++++- models/yolo.py | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/models/common.py b/models/common.py index 2977e1f83b04..ea3190fdadd5 100644 --- a/models/common.py +++ b/models/common.py @@ -863,19 +863,23 @@ def __init__(self, in_channels): self.dropout_normal = nn.Dropout(0.5) # self.sigmoid = nn.Sigmoid() - def forward(self, x): + def forward(self, x, new_x): + print('new x:', new_x.shape) # print('----entry shape', x.shape, '---\n') x = self.cv1(x) x = self.upsample(x) + print('post unsample 1:', x.shape) # x = self.relu(x) # print('----upsample shape', x.shape, '---\n') x = self.cv2(x) x = self.upsample(x) + print('post unsample 2:', x.shape) # x = self.relu(x) # x = self.dropout_normal(x) x = self.cv3(x) x = self.upsample(x) + print('post unsample 3:', x.shape) # print('----out shape', x.shape, '---\n') # x = self.sigmoid(x) diff --git a/models/yolo.py b/models/yolo.py index 191a9204e75e..0f2c7e6e7afa 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -118,10 +118,9 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): def forward(self, x): old_x = x[:3] - new_skip_connect_info = x[4] - print(new_skip_connect_info.shape) + new_skip_connect_info = x[3] x = old_x - p = self.semantic_seg(x[0]) + p = self.semantic_seg(x[0], new_skip_connect_info) x = self.detect(self, x) return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])