Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cbam imp #1

Merged
merged 3 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,164 @@ def forward(self, x):
if isinstance(x, list):
x = torch.cat(x, 1)
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))

class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
"""
Initialize the Channel Attention module.
Args:
in_planes (int): Number of input channels.
ratio (int): Reduction ratio for the hidden channels in the channel attention block.
"""
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
self.relu = nn.ReLU()
self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
self.sigmoid = nn.Sigmoid()

def forward(self, x):
"""
Forward pass of the Channel Attention module.
Args:
x (torch.Tensor): Input tensor.
Returns:
out (torch.Tensor): Output tensor after applying channel attention.
"""
with warnings.catch_warnings():
warnings.simplefilter('ignore')
avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
out = self.sigmoid(avg_out + max_out)
return out


class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
"""
Initialize the Spatial Attention module.
Args:
kernel_size (int): Size of the convolutional kernel for spatial attention.
"""
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()

def forward(self, x):
"""
Forward pass of the Spatial Attention module.
Args:
x (torch.Tensor): Input tensor.
Returns:
out (torch.Tensor): Output tensor after applying spatial attention.
"""
with warnings.catch_warnings():
warnings.simplefilter('ignore')
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv(x)
return self.sigmoid(x)


class CBAM(nn.Module):
# ch_in, ch_out, shortcut, groups, expansion, ratio, kernel_size
def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16):
"""
Initialize the CBAM (Convolutional Block Attention Module) .
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
kernel_size (int): Size of the convolutional kernel.
shortcut (bool): Whether to use a shortcut connection.
g (int): Number of groups for grouped convolutions.
e (float): Expansion factor for hidden channels.
ratio (int): Reduction ratio for the hidden channels in the channel attention block.
"""
super(CBAM, self).__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_, c2, 3, 1, g=g)
self.add = shortcut and c1 == c2
self.channel_attention = ChannelAttention(c2, ratio)
self.spatial_attention = SpatialAttention(kernel_size)

def forward(self, x):
"""
Forward pass of the CBAM .
Args:
x (torch.Tensor): Input tensor.
Returns:
out (torch.Tensor): Output tensor after applying the CBAM bottleneck.
"""
with warnings.catch_warnings():
warnings.simplefilter('ignore')
x2 = self.cv2(self.cv1(x))
out = self.channel_attention(x2) * x2
out = self.spatial_attention(out) * out
return x + out if self.add else out



class Involution(nn.Module):

def __init__(self, c1, c2, kernel_size, stride):
"""
Initialize the Involution module.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
kernel_size (int): Size of the involution kernel.
stride (int): Stride for the involution operation.
"""
super(Involution, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.c1 = c1
reduction_ratio = 1
self.group_channels = 16
self.groups = self.c1 // self.group_channels
self.conv1 = Conv(
c1, c1 // reduction_ratio, 1)
self.conv2 = Conv(
c1 // reduction_ratio,
kernel_size ** 2 * self.groups,
1, 1)

if stride > 1:
self.avgpool = nn.AvgPool2d(stride, stride)
self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride)

def forward(self, x):
"""
Forward pass of the Involution module.
Args:
x (torch.Tensor): Input tensor.
Returns:
out (torch.Tensor): Output tensor after applying the involution operation.
"""
with warnings.catch_warnings():
warnings.simplefilter('ignore')
weight = self.conv2(x)
b, c, h, w = weight.shape
weight = weight.view(b, self.groups, self.kernel_size ** 2, h, w).unsqueeze(2)
out = self.unfold(x).view(b, self.groups, self.group_channels, self.kernel_size ** 2, h, w)
out = (weight * out).sum(dim=3).view(b, self.c1, h, w)

return out

2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in {
Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, CBAM, Involution}:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
Expand Down
3 changes: 2 additions & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ def init_seeds(seed=0, deterministic=False):
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
torch.use_deterministic_algorithms(True)
# torch.use_deterministic_algorithms(True)
torch.use_deterministic_algorithms(False, warn_only= True) #since nn.AdaptiveAvgPool2d doesn't have backward implementation during GPU training
torch.backends.cudnn.deterministic = True
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTHONHASHSEED'] = str(seed)
Expand Down