Skip to content

Commit

Permalink
atten
Browse files Browse the repository at this point in the history
  • Loading branch information
argusswift committed Jan 6, 2021
1 parent 1d398a2 commit 550c9d5
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 13 deletions.
2 changes: 0 additions & 2 deletions config/yolov4_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))

DATA_PATH = osp.join(PROJECT_PATH, 'data')
# PROJECT_PATH = "E:\YOLOV4/data"
# PROJECT_PATH = "E:\YOLOV4/"


MODEL_TYPE = {
Expand Down
4 changes: 2 additions & 2 deletions eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __predict(self, img, test_shape, valid_scale, mode):
with torch.no_grad():
start_time = current_milli_time()
if self.showatt:
_, p_d, beta = self.model(img)
_, p_d, atten = self.model(img)
else:
_, p_d = self.model(img)
self.inference_time += current_milli_time() - start_time
Expand All @@ -131,7 +131,7 @@ def __predict(self, img, test_shape, valid_scale, mode):
pred_bbox, test_shape, (org_h, org_w), valid_scale
)
if self.showatt and len(img) and mode == 'det':
self.__show_heatmap(beta, org_img)
self.__show_heatmap(atten, org_img)
return bboxes

def __show_heatmap(self, beta, img):
Expand Down
6 changes: 3 additions & 3 deletions model/YOLOv4.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,14 @@ def __init__(self, weight_path=None, out_channels=255, resume=False, showatt=Fal
self.predict_net = PredictNet(feature_channels, out_channels)

def forward(self, x):
beta = None
atten = None
features = self.backbone(x)
if self.showatt:
features[-1], beta = self.attention(features[-1])
features[-1], atten = self.attention(features[-1])
features[-1] = self.spp(features[-1])
features = self.panet(features)
predicts = self.predict_net(features)
return predicts, beta
return predicts, atten


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions model/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, weight_path=None, resume=False, showatt=False):

def forward(self, x):
out = []
[x_s, x_m, x_l], beta = self.__yolov4(x)
[x_s, x_m, x_l], atten = self.__yolov4(x)

out.append(self.__head_s(x_s))
out.append(self.__head_m(x_m))
Expand All @@ -60,7 +60,7 @@ def forward(self, x):
else:
p, p_d = list(zip(*out))
if self.__showatt:
return p, torch.cat(p_d, 0), beta
return p, torch.cat(p_d, 0), atten
return p, torch.cat(p_d, 0)


Expand Down
8 changes: 4 additions & 4 deletions model/layers/global_context_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def spatial_pool(self, x):
context_mask = self.softmax(context_mask)
beta1 = context_mask
beta2 = torch.transpose(beta1, 1, 2)
beta = torch.matmul(beta2, beta1)
atten = torch.matmul(beta2, beta1)

# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(3)
Expand All @@ -57,13 +57,13 @@ def spatial_pool(self, x):
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)

return context, beta
return context, atten

def forward(self, x):
# [N, C, 1, 1]
context, beta = self.spatial_pool(x)
context, atten = self.spatial_pool(x)
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = x + channel_add_term

return out, beta
return out, atten

0 comments on commit 550c9d5

Please sign in to comment.