diff --git a/data/hyps/hyp.hic-yolov5s.yaml b/data/hyps/hyp.hic-yolov5s.yaml
new file mode 100644
index 000000000000..80f6e6dd95c4
--- /dev/null
+++ b/data/hyps/hyp.hic-yolov5s.yaml
@@ -0,0 +1,33 @@
+# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
+# Hyperparameters for HIC-YOLOv5 small-object detection on the VisDrone dataset
+# python train.py --hyp hyp.hic-yolov5s.yaml
+
+lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937  # SGD momentum/Adam beta1
+weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
+box: 0.05  # box loss gain
+cls: 0.25  # cls loss gain
+cls_pw: 1.0  # cls BCELoss positive_weight
+obj: 0.5  # obj loss gain (scale with pixels)
+obj_pw: 1.0  # obj BCELoss positive_weight
+iou_t: 0.20  # IoU training threshold
+anchor_t: 4.0  # anchor-multiple threshold
+# anchors: 3  # anchors per output layer (0 to ignore)
+fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
+hsv_h: 0.4  # image HSV-Hue augmentation (fraction)
+hsv_s: 0.3  # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.5  # image HSV-Value augmentation (fraction)
+degrees: 0.2  # image rotation (+/- deg)
+translate: 0.1  # image translation (+/- fraction)
+scale: 0.4  # image scale (+/- gain)
+shear: 0.0  # image shear (+/- deg)
+perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0  # image flip up-down (probability)
+fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
+mixup: 0.2  # image mixup (probability)
+copy_paste: 0.1  # segment copy-paste (probability)
diff --git a/models/common.py b/models/common.py
index 75cc4e97bbc7..521551f273d0 100644
--- a/models/common.py
+++ b/models/common.py
@@ -881,3 +881,165 @@ def forward(self, x):
         if isinstance(x, list):
             x = torch.cat(x, 1)
         return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
+
+
+# contributed by @aash1999
+class ChannelAttention(nn.Module):
+
+    def __init__(self, in_planes, ratio=16):
+        """
+        Initialize the Channel Attention module.
+
+        Args:
+            in_planes (int): Number of input channels.
+            ratio (int): Reduction ratio for the hidden channels in the channel attention block.
+        """
+        super().__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+        self.relu = nn.ReLU()
+        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        """
+        Forward pass of the Channel Attention module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            out (torch.Tensor): Output tensor after applying channel attention.
+        """
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
+            max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
+            out = self.sigmoid(avg_out + max_out)
+            return out
+
+
+# contributed by @aash1999
+class SpatialAttention(nn.Module):
+
+    def __init__(self, kernel_size=7):
+        """
+        Initialize the Spatial Attention module.
+
+        Args:
+            kernel_size (int): Size of the convolutional kernel for spatial attention.
+ """ + super().__init__() + assert kernel_size in (3, 7), 'kernel size must be 3 or 7' + padding = 3 if kernel_size == 7 else 1 + self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + """ + Forward pass of the Spatial Attention module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + out (torch.Tensor): Output tensor after applying spatial attention. + """ + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv(x) + return self.sigmoid(x) + + +# contributed by @aash1999 +class CBAM(nn.Module): + # ch_in, ch_out, shortcut, groups, expansion, ratio, kernel_size + def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16): + """ + Initialize the CBAM (Convolutional Block Attention Module) . + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + kernel_size (int): Size of the convolutional kernel. + shortcut (bool): Whether to use a shortcut connection. + g (int): Number of groups for grouped convolutions. + e (float): Expansion factor for hidden channels. + ratio (int): Reduction ratio for the hidden channels in the channel attention block. + """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.add = shortcut and c1 == c2 + self.channel_attention = ChannelAttention(c2, ratio) + self.spatial_attention = SpatialAttention(kernel_size) + + def forward(self, x): + """ + Forward pass of the CBAM . + + Args: + x (torch.Tensor): Input tensor. + + Returns: + out (torch.Tensor): Output tensor after applying the CBAM bottleneck. + """ + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + x2 = self.cv2(self.cv1(x)) + out = self.channel_attention(x2) * x2 + out = self.spatial_attention(out) * out + return x + out if self.add else out + + +# contributed by @aash1999 +class Involution(nn.Module): + + def __init__(self, c1, c2, kernel_size, stride): + """ + Initialize the Involution module. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + kernel_size (int): Size of the involution kernel. + stride (int): Stride for the involution operation. + """ + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.c1 = c1 + reduction_ratio = 1 + self.group_channels = 16 + self.groups = self.c1 // self.group_channels + self.conv1 = Conv(c1, c1 // reduction_ratio, 1) + self.conv2 = Conv(c1 // reduction_ratio, kernel_size ** 2 * self.groups, 1, 1) + + if stride > 1: + self.avgpool = nn.AvgPool2d(stride, stride) + self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride) + + def forward(self, x): + """ + Forward pass of the Involution module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + out (torch.Tensor): Output tensor after applying the involution operation. 
+ """ + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + weight = self.conv2(x) + b, c, h, w = weight.shape + weight = weight.view(b, self.groups, self.kernel_size ** 2, h, w).unsqueeze(2) + out = self.unfold(x).view(b, self.groups, self.group_channels, self.kernel_size ** 2, h, w) + out = (weight * out).sum(dim=3).view(b, self.c1, h, w) + + return out diff --git a/models/hub/yolov5s-cbam-involution.yaml b/models/hub/yolov5s-cbam-involution.yaml new file mode 100644 index 000000000000..9ac132e1cd78 --- /dev/null +++ b/models/hub/yolov5s-cbam-involution.yaml @@ -0,0 +1,60 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 10 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple +anchors: + - [2.9434,4.0435, 3.8626,8.5592, 6.8534, 5.9391] + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 3, CBAM, [1024, 3]], + [-1, 1, SPPF, [1024, 5]], # 10 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Involution, [1024, 1, 1]], + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 15 + + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [512, False]], # 19 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 2], 1, Concat, [1]], + [-1, 3, C3, [256, False]], # 23 160*160 p2 head + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 19], 1, Concat, [1]], + [-1, 3, C3, [512, False]], # 26 80*80 p3 head + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 15], 1, Concat, [1]], + [-1, 3, C3, [256, False]], # 29 40*40 p4 head + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 11], 1, Concat, [1]], + [-1, 3, C3, [1024, False]], # 32 20*20 p5 head + + [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5) + ] diff --git a/models/yolo.py b/models/yolo.py index 4f4d567bec73..ad78d1fbd486 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -316,7 +316,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain if m in { Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, - BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}: + BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, CBAM, Involution}: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) diff --git a/utils/general.py b/utils/general.py index 135141e21436..f31ca46f8cff 100644 --- a/utils/general.py +++ b/utils/general.py @@ -264,7 +264,8 @@ def init_seeds(seed=0, deterministic=False): torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213 - torch.use_deterministic_algorithms(True) + # since nn.AdaptiveAvgPool2d doesn't have backward 
diff --git a/models/yolo.py b/models/yolo.py
index 4f4d567bec73..ad78d1fbd486 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -316,7 +316,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
         n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
         if m in {
                 Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
-                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
+                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, CBAM, Involution}:
             c1, c2 = ch[f], args[0]
             if c2 != no:  # if not output
                 c2 = make_divisible(c2 * gw, 8)
diff --git a/utils/general.py b/utils/general.py
index 135141e21436..f31ca46f8cff 100644
--- a/utils/general.py
+++ b/utils/general.py
@@ -264,7 +264,8 @@ def init_seeds(seed=0, deterministic=False):
     torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
     # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
     if deterministic and check_version(torch.__version__, '1.12.0'):  # https://github.com/ultralytics/yolov5/pull/8213
-        torch.use_deterministic_algorithms(True)
+        # warn only: nn.AdaptiveAvgPool2d (used by ChannelAttention) has no deterministic backward implementation on GPU
+        torch.use_deterministic_algorithms(False, warn_only=True)
         torch.backends.cudnn.deterministic = True
         os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
         os.environ['PYTHONHASHSEED'] = str(seed)
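
With the diff applied, training on VisDrone follows the usual train.py workflow, extending the command noted in the hyp file header; the image size, epoch count and initial weights below are illustrative choices, not values prescribed by this patch:

python train.py \
    --cfg models/hub/yolov5s-cbam-involution.yaml \
    --hyp data/hyps/hyp.hic-yolov5s.yaml \
    --data VisDrone.yaml \
    --img 640 --epochs 300 --weights yolov5s.pt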