Skip to content

Commit

Permalink
d (#13255)
Browse files Browse the repository at this point in the history
* test

* .

* CBAM

* fix CBAM

* change pool method

* use .sigmoid
  • Loading branch information
yothinsaengs authored Sep 16, 2024
1 parent f732292 commit e2d6b25
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 3 deletions.
48 changes: 45 additions & 3 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,51 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
c_ = int(c2 * e) # hidden channels
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))

class CAM(nn.Module):
    """Channel attention module.

    Global max- and average-pooled per-channel descriptors are passed through
    one shared two-layer bottleneck MLP, summed, squashed with a sigmoid and
    used to rescale the channels of the input.
    """

    def __init__(self, channels, r):
        """Build the shared bottleneck MLP.

        Args:
            channels: Number of channels of the input feature map.
            r: Reduction ratio; the hidden layer has ``channels // r`` units,
                but never fewer than one.
        """
        super().__init__()
        self.channels = channels
        self.r = r
        # Clamp the bottleneck width: with channels < r the integer division
        # would otherwise produce a zero-width hidden layer.
        hidden = max(1, channels // r)
        self.linear = nn.Sequential(
            nn.Linear(in_features=channels, out_features=hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden, out_features=channels, bias=True),
        )

    def forward(self, x):
        """Return ``x`` rescaled per channel by the attention gate.

        Args:
            x: 4-D tensor laid out as (batch, channels, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Reduce over the two spatial dims directly instead of building a
        # pooling kernel from the runtime (h, w); both results are (b, c).
        peak = x.amax(dim=(2, 3))
        mean = x.mean(dim=(2, 3))
        # The same MLP scores both descriptors; their sum drives the gate.
        gate = (self.linear(peak) + self.linear(mean)).sigmoid()
        # Restore the singleton spatial dims so the gate broadcasts over x.
        return gate[:, :, None, None] * x

class SAM(nn.Module):
    """Spatial attention module.

    The channel-wise max and mean maps of the input are stacked, convolved
    down to a single map, squashed with a sigmoid and used to rescale every
    spatial position of the input.
    """

    def __init__(self, bias=False, kernel_size=7):
        """Build the attention convolution.

        Args:
            bias: Whether the convolution has a bias term.
            kernel_size: Odd, positive integer side length of the square
                kernel. Padding is ``kernel_size // 2`` so that the attention
                map keeps the spatial size of the input.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer; an
                even kernel would change the spatial size and break the
                elementwise product in ``forward``.
        """
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        self.bias = bias
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            dilation=1,
            bias=bias,
        )

    def forward(self, x):
        """Return ``x`` rescaled per spatial position by the attention map.

        Args:
            x: Tensor whose dim 1 is the channel dim and which ``nn.Conv2d``
                accepts once dim 1 is reduced to two channels.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # keepdim=True keeps the channel dim, so no unsqueeze is needed.
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        avg_map = torch.mean(x, dim=1, keepdim=True)
        attention = self.conv(torch.cat((max_map, avg_map), dim=1)).sigmoid()
        # The single-channel map broadcasts across all channels of x.
        return attention * x
class CBAM(nn.Module):
    """Attention block that applies CAM, then SAM, and adds the input back."""

    def __init__(self, channels, r):
        """Create the two attention sub-modules.

        Args:
            channels: Channel count forwarded to ``CAM``.
            r: Reduction ratio forwarded to ``CAM``.
        """
        super().__init__()
        self.channels, self.r = channels, r
        # SAM is built before CAM, matching the registration order that
        # existing checkpoints and seeded initialisation rely on.
        self.sam = SAM(bias=False)
        self.cam = CAM(channels=channels, r=r)

    def forward(self, x):
        """Return the CAM -> SAM refined features plus a residual of ``x``."""
        refined = self.sam(self.cam(x))
        return refined + x

class SPP(nn.Module):
"""Implements Spatial Pyramid Pooling (SPP) for feature extraction, ref: https://arxiv.org/abs/1406.4729."""
Expand Down Expand Up @@ -1092,9 +1137,6 @@ class Classify(nn.Module):
def __init__(
self, c1, c2, k=1, s=1, p=None, g=1, dropout_p=0.0
): # ch_in, ch_out, kernel, stride, padding, groups, dropout probability
"""Initializes YOLOv5 classification head with convolution, pooling, and dropout layers for input to output
channel transformation.
"""
super().__init__()
c_ = 1280 # efficientnet_b0 size
self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
Expand Down
5 changes: 5 additions & 0 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
GhostBottleneck,
GhostConv,
Proto,
CBAM,
)
from models.experimental import MixConv2d
from utils.autoanchor import check_anchor_order
Expand Down Expand Up @@ -421,6 +422,7 @@ def parse_model(d, ch):
DWConvTranspose2d,
C3x,
}:
"""c1 = number previous chanel ,c2 = number output chanel"""
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, ch_mul)
Expand All @@ -444,6 +446,9 @@ def parse_model(d, ch):
c2 = ch[f] * args[0] ** 2
elif m is Expand:
c2 = ch[f] // args[0] ** 2
elif m is CBAM:
c1 = ch[f]
args = [c1, args[0]]
else:
c2 = ch[f]

Expand Down
52 changes: 52 additions & 0 deletions models/yolov5s_CBAM.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Ultralytics YOLOv5 🚀, AGPL-3.0 license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10, 13, 16, 30, 33, 23] # P3/8
- [30, 61, 62, 45, 59, 119] # P4/16
- [116, 90, 156, 198, 373, 326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[
[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]

# YOLOv5 v6.0 head
head: [
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, "nearest"]],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13

[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, "nearest"]],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 3, CBAM, [16]], # 18 (CBAM on P3/8-small, reduction ratio r=16)

[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 21 (P4/16-medium)
[-1, 3, CBAM, [16]], # 22 (CBAM on P4/16-medium, reduction ratio r=16)

[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 25 (P5/32-large)
[-1, 3, CBAM, [16]], # 26 (CBAM on P5/32-large, reduction ratio r=16)

[[18, 22, 26], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]

0 comments on commit e2d6b25

Please sign in to comment.