From e7e1cdb0c263c48c53bd0cb5bf729687794ce2d7 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Sat, 21 Oct 2023 22:18:16 +0530 Subject: [PATCH 01/21] imp of CBAM + Involution at common.py --- models/common.py | 161 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/models/common.py b/models/common.py index 75cc4e97bbc7..cfcde6f8754f 100644 --- a/models/common.py +++ b/models/common.py @@ -881,3 +881,164 @@ def forward(self, x): if isinstance(x, list): x = torch.cat(x, 1) return self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) + +class ChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=16): + """ + Initialize the Channel Attention module. + + Args: + in_planes (int): Number of input channels. + ratio (int): Reduction ratio for the hidden channels in the channel attention block. + """ + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) + self.relu = nn.ReLU() + self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + """ + Forward pass of the Channel Attention module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + out (torch.Tensor): Output tensor after applying channel attention. + """ + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + avg_out = self.f2(self.relu(self.f1(self.avg_pool(x)))) + max_out = self.f2(self.relu(self.f1(self.max_pool(x)))) + out = self.sigmoid(avg_out + max_out) + return out + + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + """ + Initialize the Spatial Attention module. + + Args: + kernel_size (int): Size of the convolutional kernel for spatial attention. + """ + super(SpatialAttention, self).__init__() + assert kernel_size in (3, 7), 'kernel size must be 3 or 7' + padding = 3 if kernel_size == 7 else 1 + self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + """ + Forward pass of the Spatial Attention module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + out (torch.Tensor): Output tensor after applying spatial attention. + """ + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv(x) + return self.sigmoid(x) + + +class CBAM(nn.Module): + # ch_in, ch_out, shortcut, groups, expansion, ratio, kernel_size + def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16): + """ + Initialize the CBAM (Convolutional Block Attention Module) . + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + kernel_size (int): Size of the convolutional kernel. + shortcut (bool): Whether to use a shortcut connection. + g (int): Number of groups for grouped convolutions. + e (float): Expansion factor for hidden channels. + ratio (int): Reduction ratio for the hidden channels in the channel attention block. + """ + super(CBAM, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.add = shortcut and c1 == c2 + self.channel_attention = ChannelAttention(c2, ratio) + self.spatial_attention = SpatialAttention(kernel_size) + + def forward(self, x): + """ + Forward pass of the CBAM . 
+ + Args: + x (torch.Tensor): Input tensor. + + Returns: + out (torch.Tensor): Output tensor after applying the CBAM bottleneck. + """ + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + x2 = self.cv2(self.cv1(x)) + out = self.channel_attention(x2) * x2 + out = self.spatial_attention(out) * out + return x + out if self.add else out + + + +class Involution(nn.Module): + + def __init__(self, c1, c2, kernel_size, stride): + """ + Initialize the Involution module. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + kernel_size (int): Size of the involution kernel. + stride (int): Stride for the involution operation. + """ + super(Involution, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.c1 = c1 + reduction_ratio = 1 + self.group_channels = 16 + self.groups = self.c1 // self.group_channels + self.conv1 = Conv( + c1, c1 // reduction_ratio, 1) + self.conv2 = Conv( + c1 // reduction_ratio, + kernel_size ** 2 * self.groups, + 1, 1) + + if stride > 1: + self.avgpool = nn.AvgPool2d(stride, stride) + self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride) + + def forward(self, x): + """ + Forward pass of the Involution module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + out (torch.Tensor): Output tensor after applying the involution operation. + """ + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + weight = self.conv2(x) + b, c, h, w = weight.shape + weight = weight.view(b, self.groups, self.kernel_size ** 2, h, w).unsqueeze(2) + out = self.unfold(x).view(b, self.groups, self.group_channels, self.kernel_size ** 2, h, w) + out = (weight * out).sum(dim=3).view(b, self.c1, h, w) + + return out + From 16fd02c5e21f3737fe5b20337bbce7e788804f48 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Sat, 21 Oct 2023 22:20:24 +0530 Subject: [PATCH 02/21] import CBAm and Involution into yolo.py --- models/yolo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/yolo.py b/models/yolo.py index 4f4d567bec73..ad78d1fbd486 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -316,7 +316,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain if m in { Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, - BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}: + BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, CBAM, Involution}: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) From 7eff0ef6f04d79735d3bfad3a2a77c07c5bbee91 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Sat, 21 Oct 2023 22:25:40 +0530 Subject: [PATCH 03/21] handle GPU err on use_deterministic_algorithms --- utils/general.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/general.py b/utils/general.py index 135141e21436..ba799a174039 100644 --- a/utils/general.py +++ b/utils/general.py @@ -264,7 +264,8 @@ def init_seeds(seed=0, deterministic=False): torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213 - torch.use_deterministic_algorithms(True) + # torch.use_deterministic_algorithms(True) + torch.use_deterministic_algorithms(False, 
warn_only= True) #since nn.AdaptiveAvgPool2d doesn't have backward implementation during GPU training torch.backends.cudnn.deterministic = True os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['PYTHONHASHSEED'] = str(seed) From 55ea408eafcb61affc143a5ecf080e2412e9f4be Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Sat, 21 Oct 2023 22:31:04 +0530 Subject: [PATCH 04/21] added arch. backbone to /models/ --- models/yolo5m-cbam-involution.yaml | 60 ++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 models/yolo5m-cbam-involution.yaml diff --git a/models/yolo5m-cbam-involution.yaml b/models/yolo5m-cbam-involution.yaml new file mode 100644 index 000000000000..326940f7d6a8 --- /dev/null +++ b/models/yolo5m-cbam-involution.yaml @@ -0,0 +1,60 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 10 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple +anchors: + - [2.9434,4.0435, 3.8626,8.5592, 6.8534, 5.9391] + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 3, CBAM, [1024, 3]], + [-1, 1, SPPF, [1024, 5]], # 10 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Involution, [1024, 1, 1]], + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 15 + + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [512, False]], # 19 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 2], 1, Concat, [1]], + [-1, 3, C3, [256, False]], # 23 160*160 p2 head + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 19], 1, Concat, [1]], + [-1, 3, C3, [512, False]], # 26 80*80 p3 head + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 15], 1, Concat, [1]], + [-1, 3, C3, [256, False]], # 29 40*40 p4 head + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 11], 1, Concat, [1]], + [-1, 3, C3, [1024, False]], # 32 20*20 p5 head + + [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5) + ] \ No newline at end of file From 02469f2ab0557281af9ab7253f103f1708925a42 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Sat, 21 Oct 2023 22:36:39 +0530 Subject: [PATCH 05/21] readme update --- README.md | 504 +++--------------------------------------------------- 1 file changed, 22 insertions(+), 482 deletions(-) diff --git a/README.md b/README.md index a32acb3f3629..5c4580eaf323 100644 --- a/README.md +++ b/README.md @@ -1,501 +1,41 @@ -
-


+# HIC-YOLOv5: Improved YOLOv5 for Small Object Detection -[English](README.md) | [įŽ€äŊ“中文](README.zh-CN.md) -
+## Overview -
-
+This repository contains the code for HIC-YOLOv5, an improved version of YOLOv5 tailored for small object detection. The improvements are based on the paper [HIC-YOLOv5: Improved YOLOv5 For Small Object Detection](https://arxiv.org/pdf/2309.16393v1.pdf). -YOLOv5 🚀 is the world's most loved vision AI, representing Ultralytics open-source research into future vision AI methods, incorporating lessons learned and best practices evolved over thousands of hours of research and development. +HIC-YOLOv5 incorporates Channel Attention Block (CBAM) and Involution modules for enhanced object detection, making it suitable for both CPU and GPU training. -We hope that the resources here will help you get the most out of YOLOv5. Please browse the YOLOv5 Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions! +## Installation -To request an Enterprise License please complete the form at [Ultralytics Licensing](https://ultralytics.com/license). +The installation process for HIC-YOLOv5 is identical to the YOLOv5 repository. You can follow the installation instructions provided in the [YOLOv5 GitHub repository](https://github.com/ultralytics/yolov5). -
+## Usage -
-
- -##
YOLOv8 🚀 NEW
- -We are thrilled to announce the launch of Ultralytics YOLOv8 🚀, our NEW cutting-edge, state-of-the-art (SOTA) model -released at **[https://github.com/ultralytics/ultralytics](https://github.com/ultralytics/ultralytics)**. -YOLOv8 is designed to be fast, accurate, and easy to use, making it an excellent choice for a wide range of -object detection, image segmentation and image classification tasks. - -See the [YOLOv8 Docs](https://docs.ultralytics.com) for details and get started with: - -[![PyPI version](https://badge.fury.io/py/ultralytics.svg)](https://badge.fury.io/py/ultralytics) [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://pepy.tech/project/ultralytics) - -```bash -pip install ultralytics -``` - -
- - -
- -##
Documentation
- -See the [YOLOv5 Docs](https://docs.ultralytics.com/yolov5) for full documentation on training, testing and deployment. See below for quickstart examples. - -
-Install - -Clone repo and install [requirements.txt](https://github.com/ultralytics/yolov5/blob/master/requirements.txt) in a -[**Python>=3.8.0**](https://www.python.org/) environment, including -[**PyTorch>=1.8**](https://pytorch.org/get-started/locally/). - -```bash -git clone https://github.com/ultralytics/yolov5 # clone -cd yolov5 -pip install -r requirements.txt # install -``` - -
- -
-Inference - -YOLOv5 [PyTorch Hub](https://docs.ultralytics.com/yolov5/tutorials/pytorch_hub_model_loading) inference. [Models](https://github.com/ultralytics/yolov5/tree/master/models) download automatically from the latest -YOLOv5 [release](https://github.com/ultralytics/yolov5/releases). - -```python -import torch - -# Model -model = torch.hub.load("ultralytics/yolov5", "yolov5s") # or yolov5n - yolov5x6, custom - -# Images -img = "https://ultralytics.com/images/zidane.jpg" # or file, Path, PIL, OpenCV, numpy, list - -# Inference -results = model(img) - -# Results -results.print() # or .show(), .save(), .crop(), .pandas(), etc. -``` - -
- -
-Inference with detect.py - -`detect.py` runs inference on a variety of sources, downloading [models](https://github.com/ultralytics/yolov5/tree/master/models) automatically from -the latest YOLOv5 [release](https://github.com/ultralytics/yolov5/releases) and saving results to `runs/detect`. - -```bash -python detect.py --weights yolov5s.pt --source 0 # webcam - img.jpg # image - vid.mp4 # video - screen # screenshot - path/ # directory - list.txt # list of images - list.streams # list of streams - 'path/*.jpg' # glob - 'https://youtu.be/LNwODJXcvt4' # YouTube - 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream -``` - -
- -
-Training - -The commands below reproduce YOLOv5 [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) -results. [Models](https://github.com/ultralytics/yolov5/tree/master/models) -and [datasets](https://github.com/ultralytics/yolov5/tree/master/data) download automatically from the latest -YOLOv5 [release](https://github.com/ultralytics/yolov5/releases). Training times for YOLOv5n/s/m/l/x are -1/2/4/6/8 days on a V100 GPU ([Multi-GPU](https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training) times faster). Use the -largest `--batch-size` possible, or pass `--batch-size -1` for -YOLOv5 [AutoBatch](https://github.com/ultralytics/yolov5/pull/5092). Batch sizes shown for V100-16GB. +To use HIC-YOLOv5, you can specify the configuration file with the `--cfg` argument. An example command for training might look like this: ```bash -python train.py --data coco.yaml --epochs 300 --weights '' --cfg yolov5n.yaml --batch-size 128 - yolov5s 64 - yolov5m 40 - yolov5l 24 - yolov5x 16 +python train.py --img-size 640 --batch 16 --epochs 100 --data data/coco.yaml --cfg models/yolo5m-cbam-involution.yaml ``` - - -
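Beyond the training command above, a quick way to confirm that the CBAM and Involution layers are wired into the network is to build the model straight from the new config and push a dummy batch through it. This is a minimal sketch, assuming a YOLOv5 checkout with these changes and the `models/yolo5m-cbam-involution.yaml` config referenced above:

```python
# Sanity-check sketch: build the CBAM/Involution model from the new config
# and run a dummy forward pass (assumes this patch series is applied).
import torch

from models.yolo import Model  # YOLOv5's DetectionModel alias

model = Model('models/yolo5m-cbam-involution.yaml', ch=3)  # nc and anchors come from the YAML
model.eval()

dummy = torch.zeros(1, 3, 640, 640)  # one blank 640x640 RGB image
with torch.no_grad():
    pred = model(dummy)[0]  # inference output: (batch, num_predictions, 5 + nc)

print(pred.shape)
```

If you also want the training hyperparameters used for these experiments, the `data/hyps/cbam.hyp.yaml` file added later in this series can be passed to `train.py` with `--hyp data/hyps/cbam.hyp.yaml`.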
- -
-Tutorials - -- [Train Custom Data](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data) 🚀 RECOMMENDED -- [Tips for Best Training Results](https://docs.ultralytics.com/yolov5/tutorials/tips_for_best_training_results) ☘ī¸ -- [Multi-GPU Training](https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training) -- [PyTorch Hub](https://docs.ultralytics.com/yolov5/tutorials/pytorch_hub_model_loading) 🌟 NEW -- [TFLite, ONNX, CoreML, TensorRT Export](https://docs.ultralytics.com/yolov5/tutorials/model_export) 🚀 -- [NVIDIA Jetson platform Deployment](https://docs.ultralytics.com/yolov5/tutorials/running_on_jetson_nano) 🌟 NEW -- [Test-Time Augmentation (TTA)](https://docs.ultralytics.com/yolov5/tutorials/test_time_augmentation) -- [Model Ensembling](https://docs.ultralytics.com/yolov5/tutorials/model_ensembling) -- [Model Pruning/Sparsity](https://docs.ultralytics.com/yolov5/tutorials/model_pruning_and_sparsity) -- [Hyperparameter Evolution](https://docs.ultralytics.com/yolov5/tutorials/hyperparameter_evolution) -- [Transfer Learning with Frozen Layers](https://docs.ultralytics.com/yolov5/tutorials/transfer_learning_with_frozen_layers) -- [Architecture Summary](https://docs.ultralytics.com/yolov5/tutorials/architecture_description) 🌟 NEW -- [Roboflow for Datasets, Labeling, and Active Learning](https://docs.ultralytics.com/yolov5/tutorials/roboflow_datasets_integration) -- [ClearML Logging](https://docs.ultralytics.com/yolov5/tutorials/clearml_logging_integration) 🌟 NEW -- [YOLOv5 with Neural Magic's Deepsparse](https://docs.ultralytics.com/yolov5/tutorials/neural_magic_pruning_quantization) 🌟 NEW -- [Comet Logging](https://docs.ultralytics.com/yolov5/tutorials/comet_logging_integration) 🌟 NEW - -
- -##
Integrations
- -
- - -
-
- -
- -| Roboflow | ClearML ⭐ NEW | Comet ⭐ NEW | Neural Magic ⭐ NEW | -| :--------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------: | -| Label and export your custom datasets directly to YOLOv5 for training with [Roboflow](https://roboflow.com/?ref=ultralytics) | Automatically track, visualize and even remotely train YOLOv5 using [ClearML](https://cutt.ly/yolov5-readme-clearml) (open-source!) | Free forever, [Comet](https://bit.ly/yolov5-readme-comet2) lets you save YOLOv5 models, resume training, and interactively visualise and debug predictions | Run YOLOv5 inference up to 6x faster with [Neural Magic DeepSparse](https://bit.ly/yolov5-neuralmagic) | - -##
Ultralytics HUB
- -Experience seamless AI with [Ultralytics HUB](https://bit.ly/ultralytics_hub) ⭐, the all-in-one solution for data visualization, YOLOv5 and YOLOv8 🚀 model training and deployment, without any coding. Transform images into actionable insights and bring your AI visions to life with ease using our cutting-edge platform and user-friendly [Ultralytics App](https://ultralytics.com/app_install). Start your journey for **Free** now! - - - - -##
Why YOLOv5
- -YOLOv5 has been designed to be super easy to get started and simple to learn. We prioritize real-world results. - -

- YOLOv5-P5 640 Figure
- Figure Notes - -- **COCO AP val** denotes mAP@0.5:0.95 metric measured on the 5000-image [COCO val2017](http://cocodataset.org) dataset over various inference sizes from 256 to 1536. -- **GPU Speed** measures average inference time per image on [COCO val2017](http://cocodataset.org) dataset using a [AWS p3.2xlarge](https://aws.amazon.com/ec2/instance-types/p3/) V100 instance at batch-size 32. -- **EfficientDet** data from [google/automl](https://github.com/google/automl) at batch size 8. -- **Reproduce** by `python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5n6.pt yolov5s6.pt yolov5m6.pt yolov5l6.pt yolov5x6.pt` - -
- -### Pretrained Checkpoints - -| Model | size
(pixels) | mAPval 50-95 | mAPval 50 | Speed CPU b1 (ms) | Speed V100 b1 (ms) | Speed V100 b32 (ms) | params (M) | FLOPs
@640 (B) | -| ----------------------------------------------------------------------------------------------- | --------------------- | -------------------- | ----------------- | ---------------------------- | ----------------------------- | ------------------------------ | ------------------ | ---------------------- | -| [YOLOv5n](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n.pt) | 640 | 28.0 | 45.7 | **45** | **6.3** | **0.6** | **1.9** | **4.5** | -| [YOLOv5s](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s.pt) | 640 | 37.4 | 56.8 | 98 | 6.4 | 0.9 | 7.2 | 16.5 | -| [YOLOv5m](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m.pt) | 640 | 45.4 | 64.1 | 224 | 8.2 | 1.7 | 21.2 | 49.0 | -| [YOLOv5l](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5l.pt) | 640 | 49.0 | 67.3 | 430 | 10.1 | 2.7 | 46.5 | 109.1 | -| [YOLOv5x](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x.pt) | 640 | 50.7 | 68.9 | 766 | 12.1 | 4.8 | 86.7 | 205.7 | -| | | | | | | | | | -| [YOLOv5n6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n6.pt) | 1280 | 36.0 | 54.4 | 153 | 8.1 | 2.1 | 3.2 | 4.6 | -| [YOLOv5s6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s6.pt) | 1280 | 44.8 | 63.7 | 385 | 8.2 | 3.6 | 12.6 | 16.8 | -| [YOLOv5m6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m6.pt) | 1280 | 51.3 | 69.3 | 887 | 11.1 | 6.8 | 35.7 | 50.0 | -| [YOLOv5l6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5l6.pt) | 1280 | 53.7 | 71.3 | 1784 | 15.8 | 10.5 | 76.8 | 111.4 | -| [YOLOv5x6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x6.pt)
+ [TTA] | 1280
1536 | 55.0
**55.8** | 72.7
**72.7** | 3136
- | 26.2
- | 19.4
- | 140.7
- | 209.8
- | - -
- Table Notes - -- All checkpoints are trained to 300 epochs with default settings. Nano and Small models use [hyp.scratch-low.yaml](https://github.com/ultralytics/yolov5/blob/master/data/hyps/hyp.scratch-low.yaml) hyps, all others use [hyp.scratch-high.yaml](https://github.com/ultralytics/yolov5/blob/master/data/hyps/hyp.scratch-high.yaml). -- **mAPval** values are for single-model single-scale on [COCO val2017](http://cocodataset.org) dataset.
Reproduce by `python val.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65` -- **Speed** averaged over COCO val images using a [AWS p3.2xlarge](https://aws.amazon.com/ec2/instance-types/p3/) instance. NMS times (~1 ms/img) not included.
Reproduce by `python val.py --data coco.yaml --img 640 --task speed --batch 1` -- **TTA** [Test Time Augmentation](https://docs.ultralytics.com/yolov5/tutorials/test_time_augmentation) includes reflection and scale augmentations.
Reproduce by `python val.py --data coco.yaml --img 1536 --iou 0.7 --augment` - -
- -##
Segmentation
- -Our new YOLOv5 [release v7.0](https://github.com/ultralytics/yolov5/releases/v7.0) instance segmentation models are the fastest and most accurate in the world, beating all current [SOTA benchmarks](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco). We've made them super simple to train, validate and deploy. See full details in our [Release Notes](https://github.com/ultralytics/yolov5/releases/v7.0) and visit our [YOLOv5 Segmentation Colab Notebook](https://github.com/ultralytics/yolov5/blob/master/segment/tutorial.ipynb) for quickstart tutorials. - -
- Segmentation Checkpoints - -
- - -
- -We trained YOLOv5 segmentations models on COCO for 300 epochs at image size 640 using A100 GPUs. We exported all models to ONNX FP32 for CPU speed tests and to TensorRT FP16 for GPU speed tests. We ran all speed tests on Google [Colab Pro](https://colab.research.google.com/signup) notebooks for easy reproducibility. - -| Model | size
(pixels) | mAPbox 50-95 | mAPmask 50-95 | Train time 300 epochs A100 (hours) | Speed ONNX CPU (ms) | Speed TRT A100 (ms) | params (M) | FLOPs
@640 (B) | -| ------------------------------------------------------------------------------------------ | --------------------- | -------------------- | --------------------- | --------------------------------------------- | ------------------------------ | ------------------------------ | ------------------ | ---------------------- | -| [YOLOv5n-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n-seg.pt) | 640 | 27.6 | 23.4 | 80:17 | **62.7** | **1.2** | **2.0** | **7.1** | -| [YOLOv5s-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s-seg.pt) | 640 | 37.6 | 31.7 | 88:16 | 173.3 | 1.4 | 7.6 | 26.4 | -| [YOLOv5m-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m-seg.pt) | 640 | 45.0 | 37.1 | 108:36 | 427.0 | 2.2 | 22.0 | 70.8 | -| [YOLOv5l-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5l-seg.pt) | 640 | 49.0 | 39.9 | 66:43 (2x) | 857.4 | 2.9 | 47.9 | 147.7 | -| [YOLOv5x-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x-seg.pt) | 640 | **50.7** | **41.4** | 62:56 (3x) | 1579.2 | 4.5 | 88.8 | 265.7 | - -- All checkpoints are trained to 300 epochs with SGD optimizer with `lr0=0.01` and `weight_decay=5e-5` at image size 640 and all default settings.
Runs logged to https://wandb.ai/glenn-jocher/YOLOv5_v70_official -- **Accuracy** values are for single-model single-scale on COCO dataset.
Reproduce by `python segment/val.py --data coco.yaml --weights yolov5s-seg.pt` -- **Speed** averaged over 100 inference images using a [Colab Pro](https://colab.research.google.com/signup) A100 High-RAM instance. Values indicate inference speed only (NMS adds about 1ms per image).
Reproduce by `python segment/val.py --data coco.yaml --weights yolov5s-seg.pt --batch 1` -- **Export** to ONNX at FP32 and TensorRT at FP16 done with `export.py`.
Reproduce by `python export.py --weights yolov5s-seg.pt --include engine --device 0 --half` - -
- -
- Segmentation Usage Examples  Open In Colab - -### Train - -YOLOv5 segmentation training supports auto-download COCO128-seg segmentation dataset with `--data coco128-seg.yaml` argument and manual download of COCO-segments dataset with `bash data/scripts/get_coco.sh --train --val --segments` and then `python train.py --data coco.yaml`. - -```bash -# Single-GPU -python segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 - -# Multi-GPU DDP -python -m torch.distributed.run --nproc_per_node 4 --master_port 1 segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 --device 0,1,2,3 -``` - -### Val - -Validate YOLOv5s-seg mask mAP on COCO dataset: - -```bash -bash data/scripts/get_coco.sh --val --segments # download COCO val segments split (780MB, 5000 images) -python segment/val.py --weights yolov5s-seg.pt --data coco.yaml --img 640 # validate -``` - -### Predict - -Use pretrained YOLOv5m-seg.pt to predict bus.jpg: - -```bash -python segment/predict.py --weights yolov5m-seg.pt --source data/images/bus.jpg -``` - -```python -model = torch.hub.load( - "ultralytics/yolov5", "custom", "yolov5m-seg.pt" -) # load from PyTorch Hub (WARNING: inference not yet supported) -``` - -| ![zidane](https://user-images.githubusercontent.com/26833433/203113421-decef4c4-183d-4a0a-a6c2-6435b33bc5d3.jpg) | ![bus](https://user-images.githubusercontent.com/26833433/203113416-11fe0025-69f7-4874-a0a6-65d0bfe2999a.jpg) | -| ---------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------- | - -### Export - -Export YOLOv5s-seg model to ONNX and TensorRT: - -```bash -python export.py --weights yolov5s-seg.pt --include onnx engine --img 640 --device 0 -``` - -
- -##
Classification
- -YOLOv5 [release v6.2](https://github.com/ultralytics/yolov5/releases) brings support for classification model training, validation and deployment! See full details in our [Release Notes](https://github.com/ultralytics/yolov5/releases/v6.2) and visit our [YOLOv5 Classification Colab Notebook](https://github.com/ultralytics/yolov5/blob/master/classify/tutorial.ipynb) for quickstart tutorials. - -
- Classification Checkpoints - -
- -We trained YOLOv5-cls classification models on ImageNet for 90 epochs using a 4xA100 instance, and we trained ResNet and EfficientNet models alongside with the same default training settings to compare. We exported all models to ONNX FP32 for CPU speed tests and to TensorRT FP16 for GPU speed tests. We ran all speed tests on Google [Colab Pro](https://colab.research.google.com/signup) for easy reproducibility. - -| Model | size
(pixels) | acc top1 | acc top5 | Training 90 epochs 4xA100 (hours) | Speed ONNX CPU (ms) | Speed TensorRT V100 (ms) | params (M) | FLOPs
@224 (B) | -| -------------------------------------------------------------------------------------------------- | --------------------- | ---------------- | ---------------- | -------------------------------------------- | ------------------------------ | ----------------------------------- | ------------------ | ---------------------- | -| [YOLOv5n-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n-cls.pt) | 224 | 64.6 | 85.4 | 7:59 | **3.3** | **0.5** | **2.5** | **0.5** | -| [YOLOv5s-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s-cls.pt) | 224 | 71.5 | 90.2 | 8:09 | 6.6 | 0.6 | 5.4 | 1.4 | -| [YOLOv5m-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m-cls.pt) | 224 | 75.9 | 92.9 | 10:06 | 15.5 | 0.9 | 12.9 | 3.9 | -| [YOLOv5l-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5l-cls.pt) | 224 | 78.0 | 94.0 | 11:56 | 26.9 | 1.4 | 26.5 | 8.5 | -| [YOLOv5x-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x-cls.pt) | 224 | **79.0** | **94.4** | 15:04 | 54.3 | 1.8 | 48.1 | 15.9 | -| | | | | | | | | | -| [ResNet18](https://github.com/ultralytics/yolov5/releases/download/v7.0/resnet18.pt) | 224 | 70.3 | 89.5 | **6:47** | 11.2 | 0.5 | 11.7 | 3.7 | -| [ResNet34](https://github.com/ultralytics/yolov5/releases/download/v7.0/resnet34.pt) | 224 | 73.9 | 91.8 | 8:33 | 20.6 | 0.9 | 21.8 | 7.4 | -| [ResNet50](https://github.com/ultralytics/yolov5/releases/download/v7.0/resnet50.pt) | 224 | 76.8 | 93.4 | 11:10 | 23.4 | 1.0 | 25.6 | 8.5 | -| [ResNet101](https://github.com/ultralytics/yolov5/releases/download/v7.0/resnet101.pt) | 224 | 78.5 | 94.3 | 17:10 | 42.1 | 1.9 | 44.5 | 15.9 | -| | | | | | | | | | -| [EfficientNet_b0](https://github.com/ultralytics/yolov5/releases/download/v7.0/efficientnet_b0.pt) | 224 | 75.1 | 92.4 | 13:03 | 12.5 | 1.3 | 5.3 | 1.0 | -| [EfficientNet_b1](https://github.com/ultralytics/yolov5/releases/download/v7.0/efficientnet_b1.pt) | 224 | 76.4 | 93.2 | 17:04 | 14.9 | 1.6 | 7.8 | 1.5 | -| [EfficientNet_b2](https://github.com/ultralytics/yolov5/releases/download/v7.0/efficientnet_b2.pt) | 224 | 76.6 | 93.4 | 17:10 | 15.9 | 1.6 | 9.1 | 1.7 | -| [EfficientNet_b3](https://github.com/ultralytics/yolov5/releases/download/v7.0/efficientnet_b3.pt) | 224 | 77.7 | 94.0 | 19:19 | 18.9 | 1.9 | 12.2 | 2.4 | - -
- Table Notes (click to expand) - -- All checkpoints are trained to 90 epochs with SGD optimizer with `lr0=0.001` and `weight_decay=5e-5` at image size 224 and all default settings.
Runs logged to https://wandb.ai/glenn-jocher/YOLOv5-Classifier-v6-2 -- **Accuracy** values are for single-model single-scale on [ImageNet-1k](https://www.image-net.org/index.php) dataset.
Reproduce by `python classify/val.py --data ../datasets/imagenet --img 224` -- **Speed** averaged over 100 inference images using a Google [Colab Pro](https://colab.research.google.com/signup) V100 High-RAM instance.
Reproduce by `python classify/val.py --data ../datasets/imagenet --img 224 --batch 1` -- **Export** to ONNX at FP32 and TensorRT at FP16 done with `export.py`.
Reproduce by `python export.py --weights yolov5s-cls.pt --include engine onnx --imgsz 224` - -
-
- -
- Classification Usage Examples  Open In Colab - -### Train - -YOLOv5 classification training supports auto-download of MNIST, Fashion-MNIST, CIFAR10, CIFAR100, Imagenette, Imagewoof, and ImageNet datasets with the `--data` argument. To start training on MNIST for example use `--data mnist`. - -```bash -# Single-GPU -python classify/train.py --model yolov5s-cls.pt --data cifar100 --epochs 5 --img 224 --batch 128 - -# Multi-GPU DDP -python -m torch.distributed.run --nproc_per_node 4 --master_port 1 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3 -``` - -### Val - -Validate YOLOv5m-cls accuracy on ImageNet-1k dataset: - -```bash -bash data/scripts/get_imagenet.sh --val # download ImageNet val split (6.3G, 50000 images) -python classify/val.py --weights yolov5m-cls.pt --data ../datasets/imagenet --img 224 # validate -``` - -### Predict - -Use pretrained YOLOv5s-cls.pt to predict bus.jpg: - -```bash -python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg -``` - -```python -model = torch.hub.load( - "ultralytics/yolov5", "custom", "yolov5s-cls.pt" -) # load from PyTorch Hub -``` - -### Export - -Export a group of trained YOLOv5s-cls, ResNet and EfficientNet models to ONNX and TensorRT: - -```bash -python export.py --weights yolov5s-cls.pt resnet50.pt efficientnet_b0.pt --include onnx engine --img 224 -``` - -
- -##
Environments
- -Get started in seconds with our verified environments. Click each icon below for details. - -
- -##
Contribute
- -We love your input! We want to make contributing to YOLOv5 as easy and transparent as possible. Please see our [Contributing Guide](https://docs.ultralytics.com/help/contributing/) to get started, and fill out the [YOLOv5 Survey](https://ultralytics.com/survey?utm_source=github&utm_medium=social&utm_campaign=Survey) to send us feedback on your experiences. Thank you to all our contributors! - - +- `--img-size`: Specifies the input image size. +- `--batch`: Sets the batch size for training. +- `--epochs`: Defines the number of training epochs. +- `--data`: Specifies the data configuration file. +- `--cfg`: Points to the configuration file for HIC-YOLOv5. In this case, it's the `models/yolo5m-cbam-involution.yaml`. - - +## Testing for Multi-GPU Training (TODO) -##
License
+I am actively working on adding support for multi-GPU training. Please stay tuned for updates on testing and training with multiple GPUs. -Ultralytics offers two licensing options to accommodate diverse use cases: +## Acknowledgments -- **AGPL-3.0 License**: This [OSI-approved](https://opensource.org/licenses/) open-source license is ideal for students and enthusiasts, promoting open collaboration and knowledge sharing. See the [LICENSE](https://github.com/ultralytics/yolov5/blob/master/LICENSE) file for more details. -- **Enterprise License**: Designed for commercial use, this license permits seamless integration of Ultralytics software and AI models into commercial goods and services, bypassing the open-source requirements of AGPL-3.0. If your scenario involves embedding our solutions into a commercial offering, reach out through [Ultralytics Licensing](https://ultralytics.com/license). +I want to express our gratitude to the authors of the paper "HIC-YOLOv5: Improved YOLOv5 For Small Object Detection" for their contributions, which inspired the development of HIC-YOLOv5. -##
Contact
+## License -For YOLOv5 bug reports and feature requests please visit [GitHub Issues](https://github.com/ultralytics/yolov5/issues), and join our [Discord](https://ultralytics.com/discord) community for questions and discussions! +HIC-YOLOv5 is released under the MIT License. Please refer to the LICENSE file for more details. -
-
+For additional information and updates, please refer to the [YOLOv5 GitHub repository](https://github.com/ultralytics/yolov5). -[tta]: https://docs.ultralytics.com/yolov5/tutorials/test_time_augmentation +**Note:** Be sure to refer to the official [YOLOv5 repository](https://github.com/ultralytics/yolov5) for the latest updates and documentation. \ No newline at end of file From b1b1ab9377a8fce45647a56d9edea04c96a951fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:10:51 +0000 Subject: [PATCH 06/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 2 +- models/common.py | 29 +++++++++++++---------------- models/yolo5m-cbam-involution.yaml | 2 +- utils/general.py | 3 ++- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 5c4580eaf323..72420aefb155 100644 --- a/README.md +++ b/README.md @@ -38,4 +38,4 @@ HIC-YOLOv5 is released under the MIT License. Please refer to the LICENSE file f For additional information and updates, please refer to the [YOLOv5 GitHub repository](https://github.com/ultralytics/yolov5). -**Note:** Be sure to refer to the official [YOLOv5 repository](https://github.com/ultralytics/yolov5) for the latest updates and documentation. \ No newline at end of file +**Note:** Be sure to refer to the official [YOLOv5 repository](https://github.com/ultralytics/yolov5) for the latest updates and documentation. diff --git a/models/common.py b/models/common.py index cfcde6f8754f..37885dc87a1b 100644 --- a/models/common.py +++ b/models/common.py @@ -882,7 +882,9 @@ def forward(self, x): x = torch.cat(x, 1) return self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) + class ChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=16): """ Initialize the Channel Attention module. @@ -891,7 +893,7 @@ def __init__(self, in_planes, ratio=16): in_planes (int): Number of input channels. ratio (int): Reduction ratio for the hidden channels in the channel attention block. """ - super(ChannelAttention, self).__init__() + super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) @@ -910,7 +912,7 @@ def forward(self, x): out (torch.Tensor): Output tensor after applying channel attention. """ with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter('ignore') avg_out = self.f2(self.relu(self.f1(self.avg_pool(x)))) max_out = self.f2(self.relu(self.f1(self.max_pool(x)))) out = self.sigmoid(avg_out + max_out) @@ -918,6 +920,7 @@ def forward(self, x): class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): """ Initialize the Spatial Attention module. @@ -925,7 +928,7 @@ def __init__(self, kernel_size=7): Args: kernel_size (int): Size of the convolutional kernel for spatial attention. """ - super(SpatialAttention, self).__init__() + super().__init__() assert kernel_size in (3, 7), 'kernel size must be 3 or 7' padding = 3 if kernel_size == 7 else 1 self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) @@ -942,7 +945,7 @@ def forward(self, x): out (torch.Tensor): Output tensor after applying spatial attention. 
""" with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter('ignore') avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) x = torch.cat([avg_out, max_out], dim=1) @@ -965,7 +968,7 @@ def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16): e (float): Expansion factor for hidden channels. ratio (int): Reduction ratio for the hidden channels in the channel attention block. """ - super(CBAM, self).__init__() + super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c_, c2, 3, 1, g=g) @@ -984,14 +987,13 @@ def forward(self, x): out (torch.Tensor): Output tensor after applying the CBAM bottleneck. """ with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter('ignore') x2 = self.cv2(self.cv1(x)) out = self.channel_attention(x2) * x2 out = self.spatial_attention(out) * out return x + out if self.add else out - class Involution(nn.Module): def __init__(self, c1, c2, kernel_size, stride): @@ -1004,19 +1006,15 @@ def __init__(self, c1, c2, kernel_size, stride): kernel_size (int): Size of the involution kernel. stride (int): Stride for the involution operation. """ - super(Involution, self).__init__() + super().__init__() self.kernel_size = kernel_size self.stride = stride self.c1 = c1 reduction_ratio = 1 self.group_channels = 16 self.groups = self.c1 // self.group_channels - self.conv1 = Conv( - c1, c1 // reduction_ratio, 1) - self.conv2 = Conv( - c1 // reduction_ratio, - kernel_size ** 2 * self.groups, - 1, 1) + self.conv1 = Conv(c1, c1 // reduction_ratio, 1) + self.conv2 = Conv(c1 // reduction_ratio, kernel_size ** 2 * self.groups, 1, 1) if stride > 1: self.avgpool = nn.AvgPool2d(stride, stride) @@ -1033,7 +1031,7 @@ def forward(self, x): out (torch.Tensor): Output tensor after applying the involution operation. 
""" with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter('ignore') weight = self.conv2(x) b, c, h, w = weight.shape weight = weight.view(b, self.groups, self.kernel_size ** 2, h, w).unsqueeze(2) @@ -1041,4 +1039,3 @@ def forward(self, x): out = (weight * out).sum(dim=3).view(b, self.c1, h, w) return out - diff --git a/models/yolo5m-cbam-involution.yaml b/models/yolo5m-cbam-involution.yaml index 326940f7d6a8..9ac132e1cd78 100644 --- a/models/yolo5m-cbam-involution.yaml +++ b/models/yolo5m-cbam-involution.yaml @@ -57,4 +57,4 @@ head: [-1, 3, C3, [1024, False]], # 32 20*20 p5 head [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5) - ] \ No newline at end of file + ] diff --git a/utils/general.py b/utils/general.py index ba799a174039..f6e42d6de400 100644 --- a/utils/general.py +++ b/utils/general.py @@ -265,7 +265,8 @@ def init_seeds(seed=0, deterministic=False): # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213 # torch.use_deterministic_algorithms(True) - torch.use_deterministic_algorithms(False, warn_only= True) #since nn.AdaptiveAvgPool2d doesn't have backward implementation during GPU training + torch.use_deterministic_algorithms( + False, warn_only=True) #since nn.AdaptiveAvgPool2d doesn't have backward implementation during GPU training torch.backends.cudnn.deterministic = True os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['PYTHONHASHSEED'] = str(seed) From 2ee59f68326898d6782c5d7f8032d2c3d61a00e8 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Tue, 24 Oct 2023 19:50:03 +0530 Subject: [PATCH 07/21] Update general.py refactoring to meet inline comment rules Signed-off-by: Aakash Singh --- utils/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/general.py b/utils/general.py index f6e42d6de400..863a44588462 100644 --- a/utils/general.py +++ b/utils/general.py @@ -264,9 +264,9 @@ def init_seeds(seed=0, deterministic=False): torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213 - # torch.use_deterministic_algorithms(True) + # since nn.AdaptiveAvgPool2d doesn't have backward implementation during GPU training torch.use_deterministic_algorithms( - False, warn_only=True) #since nn.AdaptiveAvgPool2d doesn't have backward implementation during GPU training + False, warn_only=True) torch.backends.cudnn.deterministic = True os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['PYTHONHASHSEED'] = str(seed) From 79112dfb02b0e8417b8bbb1687954e948ce2c677 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Tue, 24 Oct 2023 19:52:36 +0530 Subject: [PATCH 08/21] Update common.py added few comments Signed-off-by: Aakash Singh --- models/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/common.py b/models/common.py index 37885dc87a1b..dd5567eec2d3 100644 --- a/models/common.py +++ b/models/common.py @@ -882,7 +882,7 @@ def forward(self, x): x = torch.cat(x, 1) return self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) - +# contributed by @aash1999 class ChannelAttention(nn.Module): def __init__(self, in_planes, ratio=16): @@ -918,7 +918,7 @@ 
def forward(self, x): out = self.sigmoid(avg_out + max_out) return out - +# contributed by @aash1999 class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): @@ -952,7 +952,7 @@ def forward(self, x): x = self.conv(x) return self.sigmoid(x) - +# contributed by @aash1999 class CBAM(nn.Module): # ch_in, ch_out, shortcut, groups, expansion, ratio, kernel_size def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16): @@ -993,7 +993,7 @@ def forward(self, x): out = self.spatial_attention(out) * out return x + out if self.add else out - +# contributed by @aash1999 class Involution(nn.Module): def __init__(self, c1, c2, kernel_size, stride): From 3d46323e556c73859577e37a721e65d7916f835b Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Tue, 24 Oct 2023 20:30:55 +0530 Subject: [PATCH 09/21] adding hyp and model files as mentioned in paper --- data/hyps/cbam.hyp.yaml | 34 ++++++++++++++++ models/yolov5s-cbam-involution.yaml | 60 +++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 data/hyps/cbam.hyp.yaml create mode 100644 models/yolov5s-cbam-involution.yaml diff --git a/data/hyps/cbam.hyp.yaml b/data/hyps/cbam.hyp.yaml new file mode 100644 index 000000000000..f46921dc66e9 --- /dev/null +++ b/data/hyps/cbam.hyp.yaml @@ -0,0 +1,34 @@ +# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license +# Hyperparameters for low-augmentation COCO training from scratch +# python train.py --batch 64 --cfg yolov5n6.yaml --weights '' --data coco.yaml --img 640 --epochs 300 --linear +# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials + +lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) +lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) +momentum: 0.937 # SGD momentum/Adam beta1 +weight_decay: 0.0005 # optimizer weight decay 5e-4 +warmup_epochs: 3.0 # warmup epochs (fractions ok) +warmup_momentum: 0.8 # warmup initial momentum +warmup_bias_lr: 0.1 # warmup initial bias lr +box: 0.05 # box loss gain +cls: 0.25 # cls loss gain +cls_pw: 1.0 # cls BCELoss positive_weight +obj: 0.5 # obj loss gain (scale with pixels) +obj_pw: 1.0 # obj BCELoss positive_weight +iou_t: 0.20 # IoU training threshold +anchor_t: 4.0 # anchor-multiple threshold +# anchors: 3 # anchors per output layer (0 to ignore) +fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) +hsv_h: 0.4 # image HSV-Hue augmentation (fraction) +hsv_s: 0.3 # image HSV-Saturation augmentation (fraction) +hsv_v: 0.5 # image HSV-Value augmentation (fraction) +degrees: 0.2 # image rotation (+/- deg) +translate: 0.1 # image translation (+/- fraction) +scale: 0.4 # image scale (+/- gain) +shear: 0.0 # image shear (+/- deg) +perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 +flipud: 0.0 # image flip up-down (probability) +fliplr: 0.5 # image flip left-right (probability) +mosaic: 1.0 # image mosaic (probability)s +mixup: 0.2 # image mixup (probability) +copy_paste: 0.1 # segment copy-paste (probability) diff --git a/models/yolov5s-cbam-involution.yaml b/models/yolov5s-cbam-involution.yaml new file mode 100644 index 000000000000..1e5ab9041ca6 --- /dev/null +++ b/models/yolov5s-cbam-involution.yaml @@ -0,0 +1,60 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license + +# Parameters +nc: 10 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple +anchors: + - [2.9434,4.0435, 3.8626,8.5592, 6.8534, 5.9391] + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 
373,326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 3, CBAMBottleneck, [1024, 3]], + [-1, 1, SPPF, [1024, 5]], # 10 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Involution, [1024, 1, 1]], + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 15 + + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [512, False]], # 19 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 2], 1, Concat, [1]], + [-1, 3, C3, [256, False]], # 23 160*160 p2 head + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 19], 1, Concat, [1]], + [-1, 3, C3, [512, False]], # 26 80*80 p3 head + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 15], 1, Concat, [1]], + [-1, 3, C3, [256, False]], # 29 40*40 p4 head + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 11], 1, Concat, [1]], + [-1, 3, C3, [1024, False]], # 32 20*20 p5 head + + [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5) + ] \ No newline at end of file From 1204c74e78a17c58cc563660d98c61763a54b22e Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Tue, 24 Oct 2023 20:43:39 +0530 Subject: [PATCH 10/21] Delete models/yolo5m-cbam-involution.yaml Signed-off-by: Aakash Singh --- models/yolo5m-cbam-involution.yaml | 60 ------------------------------ 1 file changed, 60 deletions(-) delete mode 100644 models/yolo5m-cbam-involution.yaml diff --git a/models/yolo5m-cbam-involution.yaml b/models/yolo5m-cbam-involution.yaml deleted file mode 100644 index 9ac132e1cd78..000000000000 --- a/models/yolo5m-cbam-involution.yaml +++ /dev/null @@ -1,60 +0,0 @@ -# YOLOv5 🚀 by Ultralytics, GPL-3.0 license - -# Parameters -nc: 10 # number of classes -depth_multiple: 0.33 # model depth multiple -width_multiple: 0.50 # layer channel multiple -anchors: - - [2.9434,4.0435, 3.8626,8.5592, 6.8534, 5.9391] - - [10,13, 16,30, 33,23] # P3/8 - - [30,61, 62,45, 59,119] # P4/16 - - [116,90, 156,198, 373,326] # P5/32 - -# YOLOv5 v6.0 backbone -backbone: - # [from, number, module, args] - [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 - [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 - [-1, 3, C3, [128]], - [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 - [-1, 6, C3, [256]], - [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 - [-1, 9, C3, [512]], - [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 - [-1, 3, C3, [1024]], - [-1, 3, CBAM, [1024, 3]], - [-1, 1, SPPF, [1024, 5]], # 10 - ] - -# YOLOv5 v6.0 head -head: - [[-1, 1, Involution, [1024, 1, 1]], - [-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 6], 1, Concat, [1]], # cat backbone P4 - [-1, 3, C3, [512, False]], # 15 - - [-1, 1, Conv, [512, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 4], 1, Concat, [1]], # cat backbone P3 - [-1, 3, C3, [512, False]], # 19 - - [-1, 1, Conv, [256, 1, 1]], - [-1, 1, nn.Upsample, [None, 2, 'nearest']], - [[-1, 2], 1, Concat, [1]], - [-1, 3, C3, [256, False]], # 23 160*160 p2 head - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 19], 1, Concat, [1]], - [-1, 3, C3, [512, False]], # 26 80*80 p3 head - - [-1, 1, Conv, [256, 3, 2]], - [[-1, 15], 1, Concat, [1]], - [-1, 3, C3, [256, False]], # 29 40*40 p4 
head - - [-1, 1, Conv, [512, 3, 2]], - [[-1, 11], 1, Concat, [1]], - [-1, 3, C3, [1024, False]], # 32 20*20 p5 head - - [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5) - ] From 947266a3da3ddd99078e2a26ec5d8b48a60f0fb7 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Tue, 24 Oct 2023 20:45:19 +0530 Subject: [PATCH 11/21] Update general.py removing trailing white space Signed-off-by: Aakash Singh --- utils/general.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/utils/general.py b/utils/general.py index 863a44588462..634085880298 100644 --- a/utils/general.py +++ b/utils/general.py @@ -265,8 +265,7 @@ def init_seeds(seed=0, deterministic=False): # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213 # since nn.AdaptiveAvgPool2d doesn't have backward implementation during GPU training - torch.use_deterministic_algorithms( - False, warn_only=True) + torch.use_deterministic_algorithms(False, warn_only=True) torch.backends.cudnn.deterministic = True os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['PYTHONHASHSEED'] = str(seed) From a56bf8160cc7eba09ed1a3698418d2f6d4e405d7 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Wed, 25 Oct 2023 23:54:05 +0530 Subject: [PATCH 12/21] Update yolov5s-cbam-involution.yaml Typo correction Signed-off-by: Aakash Singh --- models/yolov5s-cbam-involution.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/yolov5s-cbam-involution.yaml b/models/yolov5s-cbam-involution.yaml index 1e5ab9041ca6..9ac132e1cd78 100644 --- a/models/yolov5s-cbam-involution.yaml +++ b/models/yolov5s-cbam-involution.yaml @@ -22,7 +22,7 @@ backbone: [-1, 9, C3, [512]], [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 [-1, 3, C3, [1024]], - [-1, 3, CBAMBottleneck, [1024, 3]], + [-1, 3, CBAM, [1024, 3]], [-1, 1, SPPF, [1024, 5]], # 10 ] @@ -57,4 +57,4 @@ head: [-1, 3, C3, [1024, False]], # 32 20*20 p5 head [[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5) - ] \ No newline at end of file + ] From 5208303e88647bc8625f7421164205e52a74d1c4 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 00:14:49 +0530 Subject: [PATCH 13/21] Update CITATION.cff Signed-off-by: Aakash Singh --- CITATION.cff | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index c277230d922f..43c5ce03ccbc 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,14 +1,13 @@ cff-version: 1.2.0 preferred-citation: type: software - message: If you use YOLOv5, please cite it as below. + message: If you use YOLOv5-cbam, please cite it as below. 
authors: - - family-names: Jocher - given-names: Glenn - orcid: "https://orcid.org/0000-0001-5950-6979" - title: "YOLOv5 by Ultralytics" - version: 7.0 - doi: 10.5281/zenodo.3908559 - date-released: 2020-5-29 + - family-names: Aakash + given-names: Singh + orcid: "https://orcid.org/0009-0000-6586-9952" + title: "HIC-Yolov5" + version: 1.0 + date-released: 2023-10-22 license: AGPL-3.0 - url: "https://github.com/ultralytics/yolov5" + url: "https://github.com/aash1999/yolov5-cbam" From ccf26645dcfbf92d1d55f004aa780a6fe42f035a Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 02:31:54 +0530 Subject: [PATCH 14/21] removed trailing spaces in general.py --- utils/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/general.py b/utils/general.py index 634085880298..f31ca46f8cff 100644 --- a/utils/general.py +++ b/utils/general.py @@ -265,7 +265,7 @@ def init_seeds(seed=0, deterministic=False): # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213 # since nn.AdaptiveAvgPool2d doesn't have backward implementation during GPU training - torch.use_deterministic_algorithms(False, warn_only=True) + torch.use_deterministic_algorithms(False, warn_only=True) torch.backends.cudnn.deterministic = True os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['PYTHONHASHSEED'] = str(seed) From 16ed93a10d8d07e813039e3cc2b8089629f10d6e Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 02:35:41 +0530 Subject: [PATCH 15/21] yapf formatting --- models/common.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/common.py b/models/common.py index dd5567eec2d3..55b29d55e6c0 100644 --- a/models/common.py +++ b/models/common.py @@ -365,9 +365,8 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, model.half() if fp16 else model.float() if extra_files['config.txt']: # load metadata dict d = json.loads(extra_files['config.txt'], - object_hook=lambda d: { - int(k) if k.isdigit() else k: v - for k, v in d.items()}) + object_hook=lambda d: {int(k) if k.isdigit() else k: v + for k, v in d.items()}) stride, names = int(d['stride']), d['names'] elif dnn: # ONNX OpenCV DNN LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...') @@ -882,9 +881,9 @@ def forward(self, x): x = torch.cat(x, 1) return self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) + # contributed by @aash1999 class ChannelAttention(nn.Module): - def __init__(self, in_planes, ratio=16): """ Initialize the Channel Attention module. @@ -918,9 +917,9 @@ def forward(self, x): out = self.sigmoid(avg_out + max_out) return out + # contributed by @aash1999 class SpatialAttention(nn.Module): - def __init__(self, kernel_size=7): """ Initialize the Spatial Attention module. @@ -952,6 +951,7 @@ def forward(self, x): x = self.conv(x) return self.sigmoid(x) + # contributed by @aash1999 class CBAM(nn.Module): # ch_in, ch_out, shortcut, groups, expansion, ratio, kernel_size @@ -993,9 +993,9 @@ def forward(self, x): out = self.spatial_attention(out) * out return x + out if self.add else out + # contributed by @aash1999 class Involution(nn.Module): - def __init__(self, c1, c2, kernel_size, stride): """ Initialize the Involution module. 
From 11ddc58b83c6981cc514fcb9a1e5b517bcf7cffb Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 02:44:09 +0530 Subject: [PATCH 16/21] yapf formatting --- export.py | 1 - models/tf.py | 1 - utils/activations.py | 3 --- utils/callbacks.py | 1 - utils/dataloaders.py | 4 ---- utils/torch_utils.py | 1 - utils/triton.py | 1 - 7 files changed, 12 deletions(-) diff --git a/export.py b/export.py index 71e4eb94d1c4..9ef70b255059 100644 --- a/export.py +++ b/export.py @@ -78,7 +78,6 @@ class iOSModel(torch.nn.Module): - def __init__(self, model, im): super().__init__() b, c, h, w = im.shape # batch, channel, height, width diff --git a/models/tf.py b/models/tf.py index 62ba3ebf0782..d361558020ae 100644 --- a/models/tf.py +++ b/models/tf.py @@ -340,7 +340,6 @@ def call(self, x): class TFProto(keras.layers.Layer): - def __init__(self, c1, c_=256, c2=32, w=None): super().__init__() self.cv1 = TFConv(c1, c_, k=3, w=w.cv1) diff --git a/utils/activations.py b/utils/activations.py index e4d4bbde5ec8..40fdc4603a3c 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -33,7 +33,6 @@ def forward(x): class MemoryEfficientMish(nn.Module): # Mish activation memory-efficient class F(torch.autograd.Function): - @staticmethod def forward(ctx, x): ctx.save_for_backward(x) @@ -66,7 +65,6 @@ class AconC(nn.Module): AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter according to "Activate or Not: Learning Customized Activation" . """ - def __init__(self, c1): super().__init__() self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1)) @@ -83,7 +81,6 @@ class MetaAconC(nn.Module): MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network according to "Activate or Not: Learning Customized Activation" . 
""" - def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r super().__init__() c2 = max(r, c1 // r) diff --git a/utils/callbacks.py b/utils/callbacks.py index c90fa824cdb4..b5a6c1c096fe 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -10,7 +10,6 @@ class Callbacks: """" Handles all registered callbacks for YOLOv5 Hooks """ - def __init__(self): # Define the available callbacks self._callbacks = { diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 1fbd0361ded4..472b372b86f6 100644 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -158,7 +158,6 @@ class InfiniteDataLoader(dataloader.DataLoader): Uses same syntax as vanilla DataLoader """ - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) @@ -178,7 +177,6 @@ class _RepeatSampler: Args: sampler (Sampler) """ - def __init__(self, sampler): self.sampler = sampler @@ -1054,7 +1052,6 @@ class HUBDatasetStats(): stats.get_json(save=False) stats.process_images() """ - def __init__(self, path='coco128.yaml', autodownload=False): # Initialize class zipped, data_dir, yaml_path = self._unzip(Path(path)) @@ -1169,7 +1166,6 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): transform: torchvision transforms, used by default album_transform: Albumentations transforms, used if installed """ - def __init__(self, root, augment, imgsz, cache=False): super().__init__(root=root) self.torch_transforms = classify_transforms(imgsz) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 13a356f3238c..c0608ebe519d 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -406,7 +406,6 @@ class ModelEMA: Keeps a moving average of everything in the model state_dict (parameters and buffers) For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage """ - def __init__(self, model, decay=0.9999, tau=2000, updates=0): # Create EMA self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA diff --git a/utils/triton.py b/utils/triton.py index b5153dad940d..57a07824cf2a 100644 --- a/utils/triton.py +++ b/utils/triton.py @@ -13,7 +13,6 @@ class TritonRemoteModel: be configured to communicate over GRPC or HTTP. It accepts Torch Tensors as input and returns them as outputs. """ - def __init__(self, url: str): """ Keyword arguments: From 02bf25611ab43ad050956c65af2852a68cb6ffbe Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 03:06:53 +0530 Subject: [PATCH 17/21] Delete CITATION.cff Signed-off-by: Aakash Singh --- CITATION.cff | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 CITATION.cff diff --git a/CITATION.cff b/CITATION.cff deleted file mode 100644 index 43c5ce03ccbc..000000000000 --- a/CITATION.cff +++ /dev/null @@ -1,13 +0,0 @@ -cff-version: 1.2.0 -preferred-citation: - type: software - message: If you use YOLOv5-cbam, please cite it as below. 
- authors: - - family-names: Aakash - given-names: Singh - orcid: "https://orcid.org/0009-0000-6586-9952" - title: "HIC-Yolov5" - version: 1.0 - date-released: 2023-10-22 - license: AGPL-3.0 - url: "https://github.com/aash1999/yolov5-cbam" From 1f85ade4bab857b90893ab3f22147affc95427e9 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 03:28:31 +0530 Subject: [PATCH 18/21] reverting the files to commit 4d687c8 --- CITATION.cff | 14 ++ README.md | 504 +++++++++++++++++++++++++++++++++++++++++-- export.py | 1 + models/tf.py | 1 + utils/activations.py | 3 + utils/callbacks.py | 1 + utils/dataloaders.py | 4 + utils/torch_utils.py | 1 + utils/triton.py | 1 + 9 files changed, 508 insertions(+), 22 deletions(-) create mode 100644 CITATION.cff diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 000000000000..c277230d922f --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,14 @@ +cff-version: 1.2.0 +preferred-citation: + type: software + message: If you use YOLOv5, please cite it as below. + authors: + - family-names: Jocher + given-names: Glenn + orcid: "https://orcid.org/0000-0001-5950-6979" + title: "YOLOv5 by Ultralytics" + version: 7.0 + doi: 10.5281/zenodo.3908559 + date-released: 2020-5-29 + license: AGPL-3.0 + url: "https://github.com/ultralytics/yolov5" diff --git a/README.md b/README.md index 72420aefb155..a32acb3f3629 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,501 @@ -# HIC-YOLOv5: Improved YOLOv5 for Small Object Detection +
+

+ + + +

-## Overview +[English](README.md) | [įŽ€äŊ“中文](README.zh-CN.md) +
-This repository contains the code for HIC-YOLOv5, an improved version of YOLOv5 tailored for small object detection. The improvements are based on the paper [HIC-YOLOv5: Improved YOLOv5 For Small Object Detection](https://arxiv.org/pdf/2309.16393v1.pdf). +
+ YOLOv5 CI + YOLOv5 Citation + Docker Pulls +
+ Run on Gradient + Open In Colab + Open In Kaggle +
+
-HIC-YOLOv5 incorporates Channel Attention Block (CBAM) and Involution modules for enhanced object detection, making it suitable for both CPU and GPU training. +YOLOv5 🚀 is the world's most loved vision AI, representing Ultralytics open-source research into future vision AI methods, incorporating lessons learned and best practices evolved over thousands of hours of research and development. -## Installation +We hope that the resources here will help you get the most out of YOLOv5. Please browse the YOLOv5 Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions! -The installation process for HIC-YOLOv5 is identical to the YOLOv5 repository. You can follow the installation instructions provided in the [YOLOv5 GitHub repository](https://github.com/ultralytics/yolov5). +To request an Enterprise License please complete the form at [Ultralytics Licensing](https://ultralytics.com/license). -## Usage +
+ + + + + + + + + + + + + + + + + + + + +
-To use HIC-YOLOv5, you can specify the configuration file with the `--cfg` argument. An example command for training might look like this: +
+
+ +##
YOLOv8 🚀 NEW
+ +We are thrilled to announce the launch of Ultralytics YOLOv8 🚀, our NEW cutting-edge, state-of-the-art (SOTA) model +released at **[https://github.com/ultralytics/ultralytics](https://github.com/ultralytics/ultralytics)**. +YOLOv8 is designed to be fast, accurate, and easy to use, making it an excellent choice for a wide range of +object detection, image segmentation and image classification tasks. + +See the [YOLOv8 Docs](https://docs.ultralytics.com) for details and get started with: + +[![PyPI version](https://badge.fury.io/py/ultralytics.svg)](https://badge.fury.io/py/ultralytics) [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://pepy.tech/project/ultralytics) + +```bash +pip install ultralytics +``` + +
+ + +
+ +##
Documentation
+ +See the [YOLOv5 Docs](https://docs.ultralytics.com/yolov5) for full documentation on training, testing and deployment. See below for quickstart examples. + +
+Install + +Clone repo and install [requirements.txt](https://github.com/ultralytics/yolov5/blob/master/requirements.txt) in a +[**Python>=3.8.0**](https://www.python.org/) environment, including +[**PyTorch>=1.8**](https://pytorch.org/get-started/locally/). + +```bash +git clone https://github.com/ultralytics/yolov5 # clone +cd yolov5 +pip install -r requirements.txt # install +``` + +
+ +
+Inference + +YOLOv5 [PyTorch Hub](https://docs.ultralytics.com/yolov5/tutorials/pytorch_hub_model_loading) inference. [Models](https://github.com/ultralytics/yolov5/tree/master/models) download automatically from the latest +YOLOv5 [release](https://github.com/ultralytics/yolov5/releases). + +```python +import torch + +# Model +model = torch.hub.load("ultralytics/yolov5", "yolov5s") # or yolov5n - yolov5x6, custom + +# Images +img = "https://ultralytics.com/images/zidane.jpg" # or file, Path, PIL, OpenCV, numpy, list + +# Inference +results = model(img) + +# Results +results.print() # or .show(), .save(), .crop(), .pandas(), etc. +``` + +
+ +
+Inference with detect.py + +`detect.py` runs inference on a variety of sources, downloading [models](https://github.com/ultralytics/yolov5/tree/master/models) automatically from +the latest YOLOv5 [release](https://github.com/ultralytics/yolov5/releases) and saving results to `runs/detect`. + +```bash +python detect.py --weights yolov5s.pt --source 0 # webcam + img.jpg # image + vid.mp4 # video + screen # screenshot + path/ # directory + list.txt # list of images + list.streams # list of streams + 'path/*.jpg' # glob + 'https://youtu.be/LNwODJXcvt4' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream +``` + +
+ +
+Training + +The commands below reproduce YOLOv5 [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) +results. [Models](https://github.com/ultralytics/yolov5/tree/master/models) +and [datasets](https://github.com/ultralytics/yolov5/tree/master/data) download automatically from the latest +YOLOv5 [release](https://github.com/ultralytics/yolov5/releases). Training times for YOLOv5n/s/m/l/x are +1/2/4/6/8 days on a V100 GPU ([Multi-GPU](https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training) times faster). Use the +largest `--batch-size` possible, or pass `--batch-size -1` for +YOLOv5 [AutoBatch](https://github.com/ultralytics/yolov5/pull/5092). Batch sizes shown for V100-16GB. ```bash -python train.py --img-size 640 --batch 16 --epochs 100 --data data/coco.yaml --cfg models/yolo5m-cbam-involution.yaml +python train.py --data coco.yaml --epochs 300 --weights '' --cfg yolov5n.yaml --batch-size 128 + yolov5s 64 + yolov5m 40 + yolov5l 24 + yolov5x 16 ``` -- `--img-size`: Specifies the input image size. -- `--batch`: Sets the batch size for training. -- `--epochs`: Defines the number of training epochs. -- `--data`: Specifies the data configuration file. -- `--cfg`: Points to the configuration file for HIC-YOLOv5. In this case, it's the `models/yolo5m-cbam-involution.yaml`. + + +
+ +
+Tutorials + +- [Train Custom Data](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data) 🚀 RECOMMENDED +- [Tips for Best Training Results](https://docs.ultralytics.com/yolov5/tutorials/tips_for_best_training_results) ☘ī¸ +- [Multi-GPU Training](https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training) +- [PyTorch Hub](https://docs.ultralytics.com/yolov5/tutorials/pytorch_hub_model_loading) 🌟 NEW +- [TFLite, ONNX, CoreML, TensorRT Export](https://docs.ultralytics.com/yolov5/tutorials/model_export) 🚀 +- [NVIDIA Jetson platform Deployment](https://docs.ultralytics.com/yolov5/tutorials/running_on_jetson_nano) 🌟 NEW +- [Test-Time Augmentation (TTA)](https://docs.ultralytics.com/yolov5/tutorials/test_time_augmentation) +- [Model Ensembling](https://docs.ultralytics.com/yolov5/tutorials/model_ensembling) +- [Model Pruning/Sparsity](https://docs.ultralytics.com/yolov5/tutorials/model_pruning_and_sparsity) +- [Hyperparameter Evolution](https://docs.ultralytics.com/yolov5/tutorials/hyperparameter_evolution) +- [Transfer Learning with Frozen Layers](https://docs.ultralytics.com/yolov5/tutorials/transfer_learning_with_frozen_layers) +- [Architecture Summary](https://docs.ultralytics.com/yolov5/tutorials/architecture_description) 🌟 NEW +- [Roboflow for Datasets, Labeling, and Active Learning](https://docs.ultralytics.com/yolov5/tutorials/roboflow_datasets_integration) +- [ClearML Logging](https://docs.ultralytics.com/yolov5/tutorials/clearml_logging_integration) 🌟 NEW +- [YOLOv5 with Neural Magic's Deepsparse](https://docs.ultralytics.com/yolov5/tutorials/neural_magic_pruning_quantization) 🌟 NEW +- [Comet Logging](https://docs.ultralytics.com/yolov5/tutorials/comet_logging_integration) 🌟 NEW + +
+ +##
Integrations
+ +
+ + +
+
+ +
+ + + + + + + + + + + +
+ +| Roboflow | ClearML ⭐ NEW | Comet ⭐ NEW | Neural Magic ⭐ NEW | +| :--------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------: | +| Label and export your custom datasets directly to YOLOv5 for training with [Roboflow](https://roboflow.com/?ref=ultralytics) | Automatically track, visualize and even remotely train YOLOv5 using [ClearML](https://cutt.ly/yolov5-readme-clearml) (open-source!) | Free forever, [Comet](https://bit.ly/yolov5-readme-comet2) lets you save YOLOv5 models, resume training, and interactively visualise and debug predictions | Run YOLOv5 inference up to 6x faster with [Neural Magic DeepSparse](https://bit.ly/yolov5-neuralmagic) | + +##
Ultralytics HUB
+ +Experience seamless AI with [Ultralytics HUB](https://bit.ly/ultralytics_hub) ⭐, the all-in-one solution for data visualization, YOLOv5 and YOLOv8 🚀 model training and deployment, without any coding. Transform images into actionable insights and bring your AI visions to life with ease using our cutting-edge platform and user-friendly [Ultralytics App](https://ultralytics.com/app_install). Start your journey for **Free** now! + + + + +##
Why YOLOv5
+ +YOLOv5 has been designed to be super easy to get started and simple to learn. We prioritize real-world results. + +

+
+ YOLOv5-P5 640 Figure + +

+
+
+ Figure Notes + +- **COCO AP val** denotes mAP@0.5:0.95 metric measured on the 5000-image [COCO val2017](http://cocodataset.org) dataset over various inference sizes from 256 to 1536. +- **GPU Speed** measures average inference time per image on [COCO val2017](http://cocodataset.org) dataset using a [AWS p3.2xlarge](https://aws.amazon.com/ec2/instance-types/p3/) V100 instance at batch-size 32. +- **EfficientDet** data from [google/automl](https://github.com/google/automl) at batch size 8. +- **Reproduce** by `python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5n6.pt yolov5s6.pt yolov5m6.pt yolov5l6.pt yolov5x6.pt` + +
+ +### Pretrained Checkpoints + +| Model | size
(pixels) | mAPval
50-95 | mAPval
50 | Speed
CPU b1
(ms) | Speed
V100 b1
(ms) | Speed
V100 b32
(ms) | params
(M) | FLOPs
@640 (B) | +| ----------------------------------------------------------------------------------------------- | --------------------- | -------------------- | ----------------- | ---------------------------- | ----------------------------- | ------------------------------ | ------------------ | ---------------------- | +| [YOLOv5n](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n.pt) | 640 | 28.0 | 45.7 | **45** | **6.3** | **0.6** | **1.9** | **4.5** | +| [YOLOv5s](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s.pt) | 640 | 37.4 | 56.8 | 98 | 6.4 | 0.9 | 7.2 | 16.5 | +| [YOLOv5m](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m.pt) | 640 | 45.4 | 64.1 | 224 | 8.2 | 1.7 | 21.2 | 49.0 | +| [YOLOv5l](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5l.pt) | 640 | 49.0 | 67.3 | 430 | 10.1 | 2.7 | 46.5 | 109.1 | +| [YOLOv5x](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x.pt) | 640 | 50.7 | 68.9 | 766 | 12.1 | 4.8 | 86.7 | 205.7 | +| | | | | | | | | | +| [YOLOv5n6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n6.pt) | 1280 | 36.0 | 54.4 | 153 | 8.1 | 2.1 | 3.2 | 4.6 | +| [YOLOv5s6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s6.pt) | 1280 | 44.8 | 63.7 | 385 | 8.2 | 3.6 | 12.6 | 16.8 | +| [YOLOv5m6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m6.pt) | 1280 | 51.3 | 69.3 | 887 | 11.1 | 6.8 | 35.7 | 50.0 | +| [YOLOv5l6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5l6.pt) | 1280 | 53.7 | 71.3 | 1784 | 15.8 | 10.5 | 76.8 | 111.4 | +| [YOLOv5x6](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x6.pt)
+ [TTA] | 1280
1536 | 55.0
**55.8** | 72.7
**72.7** | 3136
- | 26.2
- | 19.4
- | 140.7
- | 209.8
- | + +
+ Table Notes + +- All checkpoints are trained to 300 epochs with default settings. Nano and Small models use [hyp.scratch-low.yaml](https://github.com/ultralytics/yolov5/blob/master/data/hyps/hyp.scratch-low.yaml) hyps, all others use [hyp.scratch-high.yaml](https://github.com/ultralytics/yolov5/blob/master/data/hyps/hyp.scratch-high.yaml). +- **mAPval** values are for single-model single-scale on [COCO val2017](http://cocodataset.org) dataset.
Reproduce by `python val.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65` +- **Speed** averaged over COCO val images using a [AWS p3.2xlarge](https://aws.amazon.com/ec2/instance-types/p3/) instance. NMS times (~1 ms/img) not included.
Reproduce by `python val.py --data coco.yaml --img 640 --task speed --batch 1` +- **TTA** [Test Time Augmentation](https://docs.ultralytics.com/yolov5/tutorials/test_time_augmentation) includes reflection and scale augmentations.
Reproduce by `python val.py --data coco.yaml --img 1536 --iou 0.7 --augment` + +
+ +##
Segmentation
+ +Our new YOLOv5 [release v7.0](https://github.com/ultralytics/yolov5/releases/v7.0) instance segmentation models are the fastest and most accurate in the world, beating all current [SOTA benchmarks](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco). We've made them super simple to train, validate and deploy. See full details in our [Release Notes](https://github.com/ultralytics/yolov5/releases/v7.0) and visit our [YOLOv5 Segmentation Colab Notebook](https://github.com/ultralytics/yolov5/blob/master/segment/tutorial.ipynb) for quickstart tutorials. + +
+ Segmentation Checkpoints + +
+ + +
+ +We trained YOLOv5 segmentations models on COCO for 300 epochs at image size 640 using A100 GPUs. We exported all models to ONNX FP32 for CPU speed tests and to TensorRT FP16 for GPU speed tests. We ran all speed tests on Google [Colab Pro](https://colab.research.google.com/signup) notebooks for easy reproducibility. + +| Model | size
(pixels) | mAPbox
50-95 | mAPmask
50-95 | Train time
300 epochs
A100 (hours) | Speed
ONNX CPU
(ms) | Speed
TRT A100
(ms) | params
(M) | FLOPs
@640 (B) | +| ------------------------------------------------------------------------------------------ | --------------------- | -------------------- | --------------------- | --------------------------------------------- | ------------------------------ | ------------------------------ | ------------------ | ---------------------- | +| [YOLOv5n-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n-seg.pt) | 640 | 27.6 | 23.4 | 80:17 | **62.7** | **1.2** | **2.0** | **7.1** | +| [YOLOv5s-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s-seg.pt) | 640 | 37.6 | 31.7 | 88:16 | 173.3 | 1.4 | 7.6 | 26.4 | +| [YOLOv5m-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m-seg.pt) | 640 | 45.0 | 37.1 | 108:36 | 427.0 | 2.2 | 22.0 | 70.8 | +| [YOLOv5l-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5l-seg.pt) | 640 | 49.0 | 39.9 | 66:43 (2x) | 857.4 | 2.9 | 47.9 | 147.7 | +| [YOLOv5x-seg](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x-seg.pt) | 640 | **50.7** | **41.4** | 62:56 (3x) | 1579.2 | 4.5 | 88.8 | 265.7 | + +- All checkpoints are trained to 300 epochs with SGD optimizer with `lr0=0.01` and `weight_decay=5e-5` at image size 640 and all default settings.
Runs logged to https://wandb.ai/glenn-jocher/YOLOv5_v70_official +- **Accuracy** values are for single-model single-scale on COCO dataset.
Reproduce by `python segment/val.py --data coco.yaml --weights yolov5s-seg.pt` +- **Speed** averaged over 100 inference images using a [Colab Pro](https://colab.research.google.com/signup) A100 High-RAM instance. Values indicate inference speed only (NMS adds about 1ms per image).
Reproduce by `python segment/val.py --data coco.yaml --weights yolov5s-seg.pt --batch 1` +- **Export** to ONNX at FP32 and TensorRT at FP16 done with `export.py`.
Reproduce by `python export.py --weights yolov5s-seg.pt --include engine --device 0 --half` + +
+ +
+ Segmentation Usage Examples  Open In Colab + +### Train + +YOLOv5 segmentation training supports auto-download COCO128-seg segmentation dataset with `--data coco128-seg.yaml` argument and manual download of COCO-segments dataset with `bash data/scripts/get_coco.sh --train --val --segments` and then `python train.py --data coco.yaml`. + +```bash +# Single-GPU +python segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 + +# Multi-GPU DDP +python -m torch.distributed.run --nproc_per_node 4 --master_port 1 segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 --device 0,1,2,3 +``` + +### Val + +Validate YOLOv5s-seg mask mAP on COCO dataset: + +```bash +bash data/scripts/get_coco.sh --val --segments # download COCO val segments split (780MB, 5000 images) +python segment/val.py --weights yolov5s-seg.pt --data coco.yaml --img 640 # validate +``` + +### Predict + +Use pretrained YOLOv5m-seg.pt to predict bus.jpg: + +```bash +python segment/predict.py --weights yolov5m-seg.pt --source data/images/bus.jpg +``` + +```python +model = torch.hub.load( + "ultralytics/yolov5", "custom", "yolov5m-seg.pt" +) # load from PyTorch Hub (WARNING: inference not yet supported) +``` + +| ![zidane](https://user-images.githubusercontent.com/26833433/203113421-decef4c4-183d-4a0a-a6c2-6435b33bc5d3.jpg) | ![bus](https://user-images.githubusercontent.com/26833433/203113416-11fe0025-69f7-4874-a0a6-65d0bfe2999a.jpg) | +| ---------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------- | + +### Export + +Export YOLOv5s-seg model to ONNX and TensorRT: + +```bash +python export.py --weights yolov5s-seg.pt --include onnx engine --img 640 --device 0 +``` + +
+ +##
Classification
+ +YOLOv5 [release v6.2](https://github.com/ultralytics/yolov5/releases) brings support for classification model training, validation and deployment! See full details in our [Release Notes](https://github.com/ultralytics/yolov5/releases/v6.2) and visit our [YOLOv5 Classification Colab Notebook](https://github.com/ultralytics/yolov5/blob/master/classify/tutorial.ipynb) for quickstart tutorials. + +
+ Classification Checkpoints + +
+ +We trained YOLOv5-cls classification models on ImageNet for 90 epochs using a 4xA100 instance, and we trained ResNet and EfficientNet models alongside with the same default training settings to compare. We exported all models to ONNX FP32 for CPU speed tests and to TensorRT FP16 for GPU speed tests. We ran all speed tests on Google [Colab Pro](https://colab.research.google.com/signup) for easy reproducibility. + +| Model | size
(pixels) | acc
top1 | acc
top5 | Training
90 epochs
4xA100 (hours) | Speed
ONNX CPU
(ms) | Speed
TensorRT V100
(ms) | params
(M) | FLOPs
@224 (B) | +| -------------------------------------------------------------------------------------------------- | --------------------- | ---------------- | ---------------- | -------------------------------------------- | ------------------------------ | ----------------------------------- | ------------------ | ---------------------- | +| [YOLOv5n-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5n-cls.pt) | 224 | 64.6 | 85.4 | 7:59 | **3.3** | **0.5** | **2.5** | **0.5** | +| [YOLOv5s-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s-cls.pt) | 224 | 71.5 | 90.2 | 8:09 | 6.6 | 0.6 | 5.4 | 1.4 | +| [YOLOv5m-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5m-cls.pt) | 224 | 75.9 | 92.9 | 10:06 | 15.5 | 0.9 | 12.9 | 3.9 | +| [YOLOv5l-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5l-cls.pt) | 224 | 78.0 | 94.0 | 11:56 | 26.9 | 1.4 | 26.5 | 8.5 | +| [YOLOv5x-cls](https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x-cls.pt) | 224 | **79.0** | **94.4** | 15:04 | 54.3 | 1.8 | 48.1 | 15.9 | +| | | | | | | | | | +| [ResNet18](https://github.com/ultralytics/yolov5/releases/download/v7.0/resnet18.pt) | 224 | 70.3 | 89.5 | **6:47** | 11.2 | 0.5 | 11.7 | 3.7 | +| [ResNet34](https://github.com/ultralytics/yolov5/releases/download/v7.0/resnet34.pt) | 224 | 73.9 | 91.8 | 8:33 | 20.6 | 0.9 | 21.8 | 7.4 | +| [ResNet50](https://github.com/ultralytics/yolov5/releases/download/v7.0/resnet50.pt) | 224 | 76.8 | 93.4 | 11:10 | 23.4 | 1.0 | 25.6 | 8.5 | +| [ResNet101](https://github.com/ultralytics/yolov5/releases/download/v7.0/resnet101.pt) | 224 | 78.5 | 94.3 | 17:10 | 42.1 | 1.9 | 44.5 | 15.9 | +| | | | | | | | | | +| [EfficientNet_b0](https://github.com/ultralytics/yolov5/releases/download/v7.0/efficientnet_b0.pt) | 224 | 75.1 | 92.4 | 13:03 | 12.5 | 1.3 | 5.3 | 1.0 | +| [EfficientNet_b1](https://github.com/ultralytics/yolov5/releases/download/v7.0/efficientnet_b1.pt) | 224 | 76.4 | 93.2 | 17:04 | 14.9 | 1.6 | 7.8 | 1.5 | +| [EfficientNet_b2](https://github.com/ultralytics/yolov5/releases/download/v7.0/efficientnet_b2.pt) | 224 | 76.6 | 93.4 | 17:10 | 15.9 | 1.6 | 9.1 | 1.7 | +| [EfficientNet_b3](https://github.com/ultralytics/yolov5/releases/download/v7.0/efficientnet_b3.pt) | 224 | 77.7 | 94.0 | 19:19 | 18.9 | 1.9 | 12.2 | 2.4 | + +
+ Table Notes (click to expand) + +- All checkpoints are trained to 90 epochs with SGD optimizer with `lr0=0.001` and `weight_decay=5e-5` at image size 224 and all default settings.
Runs logged to https://wandb.ai/glenn-jocher/YOLOv5-Classifier-v6-2 +- **Accuracy** values are for single-model single-scale on [ImageNet-1k](https://www.image-net.org/index.php) dataset.
Reproduce by `python classify/val.py --data ../datasets/imagenet --img 224` +- **Speed** averaged over 100 inference images using a Google [Colab Pro](https://colab.research.google.com/signup) V100 High-RAM instance.
Reproduce by `python classify/val.py --data ../datasets/imagenet --img 224 --batch 1` +- **Export** to ONNX at FP32 and TensorRT at FP16 done with `export.py`.
Reproduce by `python export.py --weights yolov5s-cls.pt --include engine onnx --imgsz 224` + +
+
+ +
+ Classification Usage Examples  Open In Colab + +### Train + +YOLOv5 classification training supports auto-download of MNIST, Fashion-MNIST, CIFAR10, CIFAR100, Imagenette, Imagewoof, and ImageNet datasets with the `--data` argument. To start training on MNIST for example use `--data mnist`. + +```bash +# Single-GPU +python classify/train.py --model yolov5s-cls.pt --data cifar100 --epochs 5 --img 224 --batch 128 + +# Multi-GPU DDP +python -m torch.distributed.run --nproc_per_node 4 --master_port 1 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3 +``` + +### Val + +Validate YOLOv5m-cls accuracy on ImageNet-1k dataset: + +```bash +bash data/scripts/get_imagenet.sh --val # download ImageNet val split (6.3G, 50000 images) +python classify/val.py --weights yolov5m-cls.pt --data ../datasets/imagenet --img 224 # validate +``` + +### Predict + +Use pretrained YOLOv5s-cls.pt to predict bus.jpg: + +```bash +python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg +``` + +```python +model = torch.hub.load( + "ultralytics/yolov5", "custom", "yolov5s-cls.pt" +) # load from PyTorch Hub +``` + +### Export + +Export a group of trained YOLOv5s-cls, ResNet and EfficientNet models to ONNX and TensorRT: + +```bash +python export.py --weights yolov5s-cls.pt resnet50.pt efficientnet_b0.pt --include onnx engine --img 224 +``` + +
+ +##
Environments
+ +Get started in seconds with our verified environments. Click each icon below for details. + +
+ + + + + + + + + + + + + + + + + +
+ +##
Contribute
+ +We love your input! We want to make contributing to YOLOv5 as easy and transparent as possible. Please see our [Contributing Guide](https://docs.ultralytics.com/help/contributing/) to get started, and fill out the [YOLOv5 Survey](https://ultralytics.com/survey?utm_source=github&utm_medium=social&utm_campaign=Survey) to send us feedback on your experiences. Thank you to all our contributors! + + -## Testing for Multi-GPU Training (TODO) + + -I am actively working on adding support for multi-GPU training. Please stay tuned for updates on testing and training with multiple GPUs. +##
License
-## Acknowledgments +Ultralytics offers two licensing options to accommodate diverse use cases: -I want to express our gratitude to the authors of the paper "HIC-YOLOv5: Improved YOLOv5 For Small Object Detection" for their contributions, which inspired the development of HIC-YOLOv5. +- **AGPL-3.0 License**: This [OSI-approved](https://opensource.org/licenses/) open-source license is ideal for students and enthusiasts, promoting open collaboration and knowledge sharing. See the [LICENSE](https://github.com/ultralytics/yolov5/blob/master/LICENSE) file for more details. +- **Enterprise License**: Designed for commercial use, this license permits seamless integration of Ultralytics software and AI models into commercial goods and services, bypassing the open-source requirements of AGPL-3.0. If your scenario involves embedding our solutions into a commercial offering, reach out through [Ultralytics Licensing](https://ultralytics.com/license). -## License +##
Contact
-HIC-YOLOv5 is released under the MIT License. Please refer to the LICENSE file for more details. +For YOLOv5 bug reports and feature requests please visit [GitHub Issues](https://github.com/ultralytics/yolov5/issues), and join our [Discord](https://ultralytics.com/discord) community for questions and discussions! -For additional information and updates, please refer to the [YOLOv5 GitHub repository](https://github.com/ultralytics/yolov5). +
+
+ + + + + + + + + + + + + + + + + + + + +
-**Note:** Be sure to refer to the official [YOLOv5 repository](https://github.com/ultralytics/yolov5) for the latest updates and documentation. +[tta]: https://docs.ultralytics.com/yolov5/tutorials/test_time_augmentation diff --git a/export.py b/export.py index 9ef70b255059..71e4eb94d1c4 100644 --- a/export.py +++ b/export.py @@ -78,6 +78,7 @@ class iOSModel(torch.nn.Module): + def __init__(self, model, im): super().__init__() b, c, h, w = im.shape # batch, channel, height, width diff --git a/models/tf.py b/models/tf.py index d361558020ae..62ba3ebf0782 100644 --- a/models/tf.py +++ b/models/tf.py @@ -340,6 +340,7 @@ def call(self, x): class TFProto(keras.layers.Layer): + def __init__(self, c1, c_=256, c2=32, w=None): super().__init__() self.cv1 = TFConv(c1, c_, k=3, w=w.cv1) diff --git a/utils/activations.py b/utils/activations.py index 40fdc4603a3c..e4d4bbde5ec8 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -33,6 +33,7 @@ def forward(x): class MemoryEfficientMish(nn.Module): # Mish activation memory-efficient class F(torch.autograd.Function): + @staticmethod def forward(ctx, x): ctx.save_for_backward(x) @@ -65,6 +66,7 @@ class AconC(nn.Module): AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter according to "Activate or Not: Learning Customized Activation" . """ + def __init__(self, c1): super().__init__() self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1)) @@ -81,6 +83,7 @@ class MetaAconC(nn.Module): MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network according to "Activate or Not: Learning Customized Activation" . """ + def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r super().__init__() c2 = max(r, c1 // r) diff --git a/utils/callbacks.py b/utils/callbacks.py index b5a6c1c096fe..c90fa824cdb4 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -10,6 +10,7 @@ class Callbacks: """" Handles all registered callbacks for YOLOv5 Hooks """ + def __init__(self): # Define the available callbacks self._callbacks = { diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 472b372b86f6..1fbd0361ded4 100644 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -158,6 +158,7 @@ class InfiniteDataLoader(dataloader.DataLoader): Uses same syntax as vanilla DataLoader """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) @@ -177,6 +178,7 @@ class _RepeatSampler: Args: sampler (Sampler) """ + def __init__(self, sampler): self.sampler = sampler @@ -1052,6 +1054,7 @@ class HUBDatasetStats(): stats.get_json(save=False) stats.process_images() """ + def __init__(self, path='coco128.yaml', autodownload=False): # Initialize class zipped, data_dir, yaml_path = self._unzip(Path(path)) @@ -1166,6 +1169,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): transform: torchvision transforms, used by default album_transform: Albumentations transforms, used if installed """ + def __init__(self, root, augment, imgsz, cache=False): super().__init__(root=root) self.torch_transforms = classify_transforms(imgsz) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index c0608ebe519d..13a356f3238c 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -406,6 +406,7 @@ class ModelEMA: Keeps a moving average of everything in the model state_dict (parameters and buffers) For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage """ + 
def __init__(self, model, decay=0.9999, tau=2000, updates=0): # Create EMA self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA diff --git a/utils/triton.py b/utils/triton.py index 57a07824cf2a..b5153dad940d 100644 --- a/utils/triton.py +++ b/utils/triton.py @@ -13,6 +13,7 @@ class TritonRemoteModel: be configured to communicate over GRPC or HTTP. It accepts Torch Tensors as input and returns them as outputs. """ + def __init__(self, url: str): """ Keyword arguments: From 8738c27b5e9ea2be8a74830511c49704040a11e2 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 03:41:19 +0530 Subject: [PATCH 19/21] yapf reformat --- models/common.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/models/common.py b/models/common.py index 55b29d55e6c0..521551f273d0 100644 --- a/models/common.py +++ b/models/common.py @@ -365,8 +365,9 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, model.half() if fp16 else model.float() if extra_files['config.txt']: # load metadata dict d = json.loads(extra_files['config.txt'], - object_hook=lambda d: {int(k) if k.isdigit() else k: v - for k, v in d.items()}) + object_hook=lambda d: { + int(k) if k.isdigit() else k: v + for k, v in d.items()}) stride, names = int(d['stride']), d['names'] elif dnn: # ONNX OpenCV DNN LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...') @@ -884,6 +885,7 @@ def forward(self, x): # contributed by @aash1999 class ChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=16): """ Initialize the Channel Attention module. @@ -920,6 +922,7 @@ def forward(self, x): # contributed by @aash1999 class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): """ Initialize the Spatial Attention module. @@ -996,6 +999,7 @@ def forward(self, x): # contributed by @aash1999 class Involution(nn.Module): + def __init__(self, c1, c2, kernel_size, stride): """ Initialize the Involution module. 
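With the yapf pass above, the new `models/common.py` blocks are in their final form. A quick standalone shape check is sketched below; it assumes the repository root is on `PYTHONPATH` and uses the constructor signatures from the patches (`CBAM(c1, c2, kernel_size, ...)` and `Involution(c1, c2, kernel_size, stride)`). The commented shapes are what the stride-1 configuration is expected to produce, not measured results.

```python
import torch

from models.common import CBAM, Involution

x = torch.randn(1, 64, 80, 80)  # NCHW feature map; 64 channels is divisible by the 16-channel involution groups

cbam = CBAM(64, 64, kernel_size=3)                 # c1 == c2, so the residual shortcut is active
inv = Involution(64, 64, kernel_size=3, stride=1)  # stride 1 keeps the spatial resolution

print(cbam(x).shape)  # expected: torch.Size([1, 64, 80, 80])
print(inv(x).shape)   # expected: torch.Size([1, 64, 80, 80])
```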
From 0fd8fe3ddc46d53fe31a7e6b65b9daf153b774fe Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 03:50:43 +0530 Subject: [PATCH 20/21] moving files to where they belong --- data/hyps/{cbam.hyp.yaml => hyp.hic-yolov5s.yaml} | 5 ++--- models/{ => hub}/yolov5s-cbam-involution.yaml | 0 2 files changed, 2 insertions(+), 3 deletions(-) rename data/hyps/{cbam.hyp.yaml => hyp.hic-yolov5s.yaml} (84%) rename models/{ => hub}/yolov5s-cbam-involution.yaml (100%) diff --git a/data/hyps/cbam.hyp.yaml b/data/hyps/hyp.hic-yolov5s.yaml similarity index 84% rename from data/hyps/cbam.hyp.yaml rename to data/hyps/hyp.hic-yolov5s.yaml index f46921dc66e9..66da95727e04 100644 --- a/data/hyps/cbam.hyp.yaml +++ b/data/hyps/hyp.hic-yolov5s.yaml @@ -1,7 +1,6 @@ # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license -# Hyperparameters for low-augmentation COCO training from scratch -# python train.py --batch 64 --cfg yolov5n6.yaml --weights '' --data coco.yaml --img 640 --epochs 300 --linear -# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials +# hyperparameters for HIC-YOLOv5 for small object detection on VisDrone Dataset +# python train.py --hyp hyp.hyp.hic-yolov5s lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) diff --git a/models/yolov5s-cbam-involution.yaml b/models/hub/yolov5s-cbam-involution.yaml similarity index 100% rename from models/yolov5s-cbam-involution.yaml rename to models/hub/yolov5s-cbam-involution.yaml From 2fc73ca6e6aebeb3f30c0b11d86f66da0858842d Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 26 Oct 2023 03:51:12 +0530 Subject: [PATCH 21/21] typo correction --- data/hyps/hyp.hic-yolov5s.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/hyps/hyp.hic-yolov5s.yaml b/data/hyps/hyp.hic-yolov5s.yaml index 66da95727e04..80f6e6dd95c4 100644 --- a/data/hyps/hyp.hic-yolov5s.yaml +++ b/data/hyps/hyp.hic-yolov5s.yaml @@ -1,6 +1,6 @@ # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license # hyperparameters for HIC-YOLOv5 for small object detection on VisDrone Dataset -# python train.py --hyp hyp.hyp.hic-yolov5s +# python train.py --hyp hyp.hic-yolov5s.yaml lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
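After PATCH 20 and PATCH 21, the model config lives at `models/hub/yolov5s-cbam-involution.yaml` and the hyperparameters at `data/hyps/hyp.hic-yolov5s.yaml`. As a rough smoke test of the relocated config (a sketch only: it assumes the standard YOLOv5 `models.yolo.Model` entry point and is run from the repository root), the network can be built and given a dummy forward pass:

```python
import torch

from models.yolo import Model

# Build HIC-YOLOv5s from the config moved in PATCH 20; nc and anchors are read from the YAML.
model = Model('models/hub/yolov5s-cbam-involution.yaml', ch=3)
model.eval()

with torch.no_grad():
    _ = model(torch.zeros(1, 3, 640, 640))  # dummy image, just to verify the graph builds
model.info(verbose=False)  # parameter / GFLOPs summary
```

Training itself would then follow the usual `train.py` flow, with `--cfg` pointing at the relocated model YAML and `--hyp` at `hyp.hic-yolov5s.yaml`, as noted in the header comment corrected by PATCH 21.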