Merge pull request #1 from aash1999/cbam-imp
Makes the code changes needed to implement CBAM and Involution in YOLOv5.
aash1999 committed Oct 21, 2023
2 parents 4d687c8 + 7eff0ef commit b7715ca
Showing 3 changed files with 164 additions and 2 deletions.
161 changes: 161 additions & 0 deletions models/common.py
@@ -881,3 +881,164 @@ def forward(self, x):
        if isinstance(x, list):
            x = torch.cat(x, 1)
        return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        """
        Initialize the Channel Attention module.
        Args:
            in_planes (int): Number of input channels.
            ratio (int): Reduction ratio for the hidden channels in the channel attention block.
        """
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Forward pass of the Channel Attention module.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            out (torch.Tensor): Output tensor after applying channel attention.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
            max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
            out = self.sigmoid(avg_out + max_out)
            return out
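
# Minimal usage sketch (illustrative comment, assuming only the module above):
# the forward pass returns a (B, C, 1, 1) sigmoid gate, which callers multiply
# element-wise with the input:
#   ca = ChannelAttention(64)   # 64-channel input, default ratio 16
#   y = ca(x) * x               # x: (B, 64, H, W) -> y: same shape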


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        """
        Initialize the Spatial Attention module.
        Args:
            kernel_size (int): Size of the convolutional kernel for spatial attention.
        """
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Forward pass of the Spatial Attention module.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            out (torch.Tensor): Output tensor after applying spatial attention.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            avg_out = torch.mean(x, dim=1, keepdim=True)
            max_out, _ = torch.max(x, dim=1, keepdim=True)
            x = torch.cat([avg_out, max_out], dim=1)
            x = self.conv(x)
            return self.sigmoid(x)
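
# Minimal usage sketch (illustrative comment, assuming only the module above):
# the forward pass stacks channel-wise mean and max maps into a 2-channel input
# and returns a (B, 1, H, W) sigmoid gate:
#   sa = SpatialAttention(kernel_size=7)
#   y = sa(x) * x               # x: (B, C, H, W) -> y: same shape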


class CBAM(nn.Module):
    # ch_in, ch_out, kernel_size, shortcut, groups, expansion, ratio
    def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16):
        """
        Initialize the CBAM (Convolutional Block Attention Module) bottleneck.
        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            kernel_size (int): Size of the convolutional kernel for spatial attention.
            shortcut (bool): Whether to use a shortcut connection.
            g (int): Number of groups for grouped convolutions.
            e (float): Expansion factor for hidden channels.
            ratio (int): Reduction ratio for the hidden channels in the channel attention block.
        """
        super(CBAM, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
        self.channel_attention = ChannelAttention(c2, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """
        Forward pass of the CBAM bottleneck.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            out (torch.Tensor): Output tensor after applying the CBAM bottleneck.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            x2 = self.cv2(self.cv1(x))
            out = self.channel_attention(x2) * x2
            out = self.spatial_attention(out) * out
            return x + out if self.add else out
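
# Minimal usage sketch (illustrative comment, assuming only the module above):
# CBAM runs the 1x1 -> 3x3 bottleneck, then gates the result by channel and
# spatial attention in turn; with c1 == c2 and shortcut=True the residual add fires:
#   cbam = CBAM(64, 64, kernel_size=7)
#   y = cbam(x)                 # x: (B, 64, H, W) -> y: same shape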



class Involution(nn.Module):

    def __init__(self, c1, c2, kernel_size, stride):
        """
        Initialize the Involution module.
        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels (unused; the output keeps c1 channels).
            kernel_size (int): Size of the involution kernel.
            stride (int): Stride for the involution operation.
        """
        super(Involution, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.c1 = c1
        reduction_ratio = 1
        self.group_channels = 16
        self.groups = self.c1 // self.group_channels
        self.conv1 = Conv(c1, c1 // reduction_ratio, 1)
        self.conv2 = Conv(c1 // reduction_ratio, kernel_size ** 2 * self.groups, 1, 1)
        if stride > 1:
            self.avgpool = nn.AvgPool2d(stride, stride)
        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride)

    def forward(self, x):
        """
        Forward pass of the Involution module.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            out (torch.Tensor): Output tensor after applying the involution operation.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            # Generate per-position kernels; downsample first when stride > 1 so the
            # kernel map matches the resolution of the unfolded output.
            weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x)))
            b, c, h, w = weight.shape
            weight = weight.view(b, self.groups, self.kernel_size ** 2, h, w).unsqueeze(2)
            out = self.unfold(x).view(b, self.groups, self.group_channels, self.kernel_size ** 2, h, w)
            out = (weight * out).sum(dim=3).view(b, self.c1, h, w)

        return out
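
A quick shape check for the four new modules, assuming it is run from the yolov5 repo root with this commit applied (Conv, and the BatchNorm/SiLU it wraps, comes from models/common.py):

import torch
from models.common import CBAM, ChannelAttention, Involution, SpatialAttention

x = torch.randn(2, 64, 32, 32)                       # (batch, channels, H, W)
print((ChannelAttention(64)(x) * x).shape)           # torch.Size([2, 64, 32, 32])
print((SpatialAttention(7)(x) * x).shape)            # torch.Size([2, 64, 32, 32])
print(CBAM(64, 64, kernel_size=7)(x).shape)          # torch.Size([2, 64, 32, 32])
print(Involution(64, 64, 3, 1)(x).shape)             # torch.Size([2, 64, 32, 32])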

2 changes: 1 addition & 1 deletion models/yolo.py
@@ -316,7 +316,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
-               BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
+               BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, CBAM, Involution}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)
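
With CBAM and Involution registered in this set, parse_model reads args[0] as the requested output-channel count c2 (scaled by make_divisible(c2 * gw, 8) unless it equals no) and prepends the incoming channels c1, so hypothetical model-YAML layers (not part of this commit) would be consumed as:

  [-1, 1, CBAM, [1024, 7]]          # -> CBAM(ch[f], c2, 7), i.e. kernel_size=7
  [-1, 1, Involution, [64, 3, 1]]   # -> Involution(ch[f], c2, 3, 1)

Note that Involution as defined above ignores c2 and keeps its input channel count, so its YAML c2 should match the incoming channels for the ch bookkeeping to stay correct.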
3 changes: 2 additions & 1 deletion utils/general.py
@@ -264,7 +264,8 @@ def init_seeds(seed=0, deterministic=False):
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic and check_version(torch.__version__, '1.12.0'):  # https://github.com/ultralytics/yolov5/pull/8213
-       torch.use_deterministic_algorithms(True)
+       # torch.use_deterministic_algorithms(True)
+       torch.use_deterministic_algorithms(True, warn_only=True)  # warn instead of raising: adaptive_avg_pool2d has no deterministic CUDA backward
        torch.backends.cudnn.deterministic = True
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        os.environ['PYTHONHASHSEED'] = str(seed)
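
With warn_only=True, operations that lack a deterministic implementation (notably the CUDA backward of the nn.AdaptiveAvgPool2d used in ChannelAttention) emit a UserWarning instead of raising a RuntimeError, so GPU training proceeds. A minimal sketch, assuming a CUDA device is available:

import torch
torch.use_deterministic_algorithms(True, warn_only=True)
x = torch.randn(1, 8, 9, 9, device='cuda', requires_grad=True)
torch.nn.AdaptiveAvgPool2d(5)(x).sum().backward()    # warns, does not raise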
