diff --git a/.github/README_cn.md b/.github/README_cn.md index 7b56462905d7..86b502df61f7 100644 --- a/.github/README_cn.md +++ b/.github/README_cn.md @@ -171,26 +171,23 @@ python train.py --data coco.yaml --cfg yolov5n.yaml --weights '' --batch-size 12 ##
如何与第三方集成
-|Weights and Biases|Roboflow ⭐ 新|
-|:-:|:-:|
-|通过 [Weights & Biases](https://wandb.ai/site?utm_campaign=repo_yolo_readme) 自动跟踪和可视化你在云端的所有YOLOv5训练运行状态。|标记并将您的自定义数据集直接导出到YOLOv5,以便用 [Roboflow](https://roboflow.com/?ref=ultralytics) 进行训练。 |
-
-
##
为什么选择 YOLOv5
@@ -239,6 +236,84 @@ We are super excited about our first-ever Ultralytics YOLOv5 🚀 EXPORT Competi + +##
Classification ⭐ NEW
+ +YOLOv5 [release v6.2](https://github.com/ultralytics/yolov5/releases) brings support for classification model training, validation, prediction and export! We've made training classifier models super simple. Click below to get started. + +
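As a quick taste of the new classification inference added later in this diff, here is a condensed sketch of what `classify/predict.py` boils down to (run from the repo root; `data/images/bus.jpg` ships with the repo):

```python
import cv2
import torch.nn.functional as F

from models.common import DetectMultiBackend          # backend-agnostic model wrapper
from utils.augmentations import classify_transforms   # resize / center-crop / normalize

model = DetectMultiBackend('yolov5s-cls.pt')                      # classification checkpoint
im = cv2.cvtColor(cv2.imread('data/images/bus.jpg'), cv2.COLOR_BGR2RGB)
im = classify_transforms(224)(im).unsqueeze(0)                    # (1, 3, 224, 224) float tensor
p = F.softmax(model(im), dim=1)                                   # class probabilities
top5 = p.argsort(1, descending=True)[:, :5].squeeze()             # 5 most likely class indices
print(', '.join(f'{model.names[int(j)]} {p[0, j]:.2f}' for j in top5))
```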
+ Classification Checkpoints (click to expand) + +
+
+We trained YOLOv5-cls classification models on ImageNet for 90 epochs using a 4xA100 instance, and we trained ResNet and EfficientNet models alongside under the same default training settings for comparison. We exported all models to ONNX FP32 for CPU speed tests and to TensorRT FP16 for GPU speed tests. We ran all speed tests on Google [Colab Pro](https://colab.research.google.com/signup) for easy reproducibility.
+
+| Model | size<br>(pixels) | acc<br>top1 | acc<br>top5 | Training<br>90 epochs<br>4xA100 (hours) | Speed<br>ONNX CPU<br>(ms) | Speed<br>TensorRT V100<br>(ms) | params<br>(M) | FLOPs<br>@224 (B) |
+|----------------------------------------------------------------------------------------------------|-----------------------|------------------|------------------|----------------------------------------------|--------------------------------|-------------------------------------|--------------------|------------------------|
+| [YOLOv5n-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5n-cls.pt) | 224 | 64.6 | 85.4 | 7:59 | **3.3** | **0.5** | **2.5** | **0.5** |
+| [YOLOv5s-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5s-cls.pt) | 224 | 71.5 | 90.2 | 8:09 | 6.6 | 0.6 | 5.4 | 1.4 |
+| [YOLOv5m-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5m-cls.pt) | 224 | 75.9 | 92.9 | 10:06 | 15.5 | 0.9 | 12.9 | 3.9 |
+| [YOLOv5l-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5l-cls.pt) | 224 | 78.0 | 94.0 | 11:56 | 26.9 | 1.4 | 26.5 | 8.5 |
+| [YOLOv5x-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5x-cls.pt) | 224 | **79.0** | **94.4** | 15:04 | 54.3 | 1.8 | 48.1 | 15.9 |
+| |
+| [ResNet18](https://github.com/ultralytics/yolov5/releases/download/v6.2/resnet18.pt) | 224 | 70.3 | 89.5 | **6:47** | 11.2 | 0.5 | 11.7 | 3.7 |
+| [ResNet34](https://github.com/ultralytics/yolov5/releases/download/v6.2/resnet34.pt) | 224 | 73.9 | 91.8 | 8:33 | 20.6 | 0.9 | 21.8 | 7.4 |
+| [ResNet50](https://github.com/ultralytics/yolov5/releases/download/v6.2/resnet50.pt) | 224 | 76.8 | 93.4 | 11:10 | 23.4 | 1.0 | 25.6 | 8.5 |
+| [ResNet101](https://github.com/ultralytics/yolov5/releases/download/v6.2/resnet101.pt) | 224 | 78.5 | 94.3 | 17:10 | 42.1 | 1.9 | 44.5 | 15.9 |
+| |
+| [EfficientNet_b0](https://github.com/ultralytics/yolov5/releases/download/v6.2/efficientnet_b0.pt) | 224 | 75.1 | 92.4 | 13:03 | 12.5 | 1.3 | 5.3 | 1.0 |
+| [EfficientNet_b1](https://github.com/ultralytics/yolov5/releases/download/v6.2/efficientnet_b1.pt) | 224 | 76.4 | 93.2 | 17:04 | 14.9 | 1.6 | 7.8 | 1.5 |
+| [EfficientNet_b2](https://github.com/ultralytics/yolov5/releases/download/v6.2/efficientnet_b2.pt) | 224 | 76.6 | 93.4 | 17:10 | 15.9 | 1.6 | 9.1 | 1.7 |
+| [EfficientNet_b3](https://github.com/ultralytics/yolov5/releases/download/v6.2/efficientnet_b3.pt) | 224 | 77.7 | 94.0 | 19:19 | 18.9 | 1.9 | 12.2 | 2.4 |
+
+ Table Notes (click to expand)
+
+- All checkpoints were trained for 90 epochs with the SGD optimizer (lr0=0.001) at image size 224 and all default settings. Runs logged to https://wandb.ai/glenn-jocher/YOLOv5-Classifier-v6-2.
+- **Accuracy** values are for single-model single-scale on the [ImageNet-1k](https://www.image-net.org/index.php) dataset.<br>Reproduce by `python classify/val.py --data ../datasets/imagenet --img 224`
+- **Speed** averaged over 100 inference images using a Google [Colab Pro](https://colab.research.google.com/signup) V100 High-RAM instance.<br>Reproduce by `python classify/val.py --data ../datasets/imagenet --img 224 --batch 1`
+- **Export** to ONNX at FP32 and TensorRT at FP16 done with `export.py`.<br>Reproduce by `python export.py --weights yolov5s-cls.pt --include engine onnx --imgsz 224`
+
+
+ +
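The top-1/top-5 columns above are computed the same way the new `classify/val.py` does it; a self-contained toy version of that bookkeeping:

```python
import torch

logits = torch.randn(4, 10)              # toy scores: batch of 4 images, 10 classes
targets = torch.tensor([3, 1, 7, 0])     # ground-truth labels

pred = logits.argsort(1, descending=True)[:, :5]   # top-5 predicted classes per image
correct = (targets[:, None] == pred).float()       # 1.0 wherever a top-5 slot matches
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1)  # per-image (top1, top5)
top1, top5 = acc.mean(0).tolist()
print(f'top1 {top1:.3f}, top5 {top5:.3f}')
```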
+ Classification Usage Examples (click to expand)
+
+### Train
+YOLOv5 classification training supports auto-download of MNIST, Fashion-MNIST, CIFAR10, CIFAR100, Imagenette, Imagewoof, and ImageNet datasets with the `--data` argument. For example, to start training on MNIST use `--data mnist`.
+
+```bash
+# Single-GPU
+python classify/train.py --model yolov5s-cls.pt --data cifar100 --epochs 5 --img 224 --batch 128
+
+# Multi-GPU DDP
+python -m torch.distributed.run --nproc_per_node 4 --master_port 1 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3
+```
+
+### Val
+Validate the accuracy of a pretrained model. To validate YOLOv5s-cls accuracy on ImageNet:
+```bash
+bash data/scripts/get_imagenet.sh --val  # download ImageNet val split (6.3G, 50000 images)
+python classify/val.py --weights yolov5s-cls.pt --data ../datasets/imagenet --img 224
+```
+
+### Predict
+Run a classification prediction on an image.
+```bash
+python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg
+```
+```python
+model = torch.hub.load('ultralytics/yolov5', 'custom', 'yolov5s-cls.pt')  # load from PyTorch Hub
+```
+
+### Export
+Export a group of trained YOLOv5-cls, ResNet and EfficientNet models to ONNX and TensorRT.
+```bash
+python export.py --weights yolov5s-cls.pt resnet50.pt efficientnet_b0.pt --include onnx engine --img 224
+```
+
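For a custom dataset, `classify/train.py` infers the class count from the subfolder names under `train/` and evaluates on `test/` if present, otherwise `val/`. A minimal sketch of the expected layout (the `flowers` dataset and its class names are hypothetical):

```python
from pathlib import Path

root = Path('../datasets/flowers')           # hypothetical custom dataset next to yolov5/
for split in ('train', 'val'):               # test/ is used instead of val/ when it exists
    for cls in ('daisy', 'rose', 'tulip'):   # one subfolder per class, read from train/
        (root / split / cls).mkdir(parents=True, exist_ok=True)
# then: python classify/train.py --model yolov5s-cls.pt --data ../datasets/flowers --img 224
```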
+ + ##
贡献
我们重视您的意见! 我们希望给大家提供尽可能的简单和透明的方式对 YOLOv5 做出贡献。开始之前请先点击并查看我们的 [贡献指南](CONTRIBUTING.md),填写[YOLOv5调查问卷](https://ultralytics.com/survey?utm_source=github&utm_medium=social&utm_campaign=Survey) 来向我们发送您的经验反馈。真诚感谢我们所有的贡献者! diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 31d38ead530f..aa797c44d487 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -5,9 +5,9 @@ name: YOLOv5 CI on: push: - branches: [master] + branches: [ master ] pull_request: - branches: [master] + branches: [ master ] schedule: - cron: '0 0 * * *' # runs at 00:00 UTC every day @@ -16,9 +16,9 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest] - python-version: ['3.9'] # requires python<=3.9 - model: [yolov5n] + os: [ ubuntu-latest ] + python-version: [ '3.9' ] # requires python<=3.9 + model: [ yolov5n ] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 @@ -47,9 +47,9 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.10'] - model: [yolov5n] + os: [ ubuntu-latest, macos-latest, windows-latest ] + python-version: [ '3.10' ] + model: [ yolov5n ] include: - os: ubuntu-latest python-version: '3.7' # '3.6.8' min @@ -87,7 +87,7 @@ jobs: else pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu fi - shell: bash # required for Windows compatibility + shell: bash # for Windows compatibility - name: Check environment run: | python -c "import utils; utils.notebook_init()" @@ -100,8 +100,8 @@ jobs: python --version pip --version pip list - - name: Run tests - shell: bash + - name: Test detection + shell: bash # for Windows compatibility run: | # export PYTHONPATH="$PWD" # to run '$ python *.py' files in subdirectories m=${{ matrix.model }} # official weights @@ -123,3 +123,13 @@ jobs: model = torch.hub.load('.', 'custom', path=path, source='local') print(model('data/images/bus.jpg')) EOF + - name: Test classification + shell: bash # for Windows compatibility + run: | + m=${{ matrix.model }}-cls.pt # official weights + b=runs/train-cls/exp/weights/best.pt # best.pt checkpoint + python classify/train.py --imgsz 32 --model $m --data mnist2560 --epochs 1 # train + python classify/val.py --imgsz 32 --weights $b --data ../datasets/mnist2560 # val + python classify/predict.py --imgsz 32 --weights $b --source ../datasets/mnist2560/test/7/60.png # predict + python classify/predict.py --imgsz 32 --weights $m --source data/images/bus.jpg # predict + python export.py --weights $b --img 64 --imgsz 224 --include torchscript # export diff --git a/README.md b/README.md index 62c7ed4f53e6..b368d1d6e264 100644 --- a/README.md +++ b/README.md @@ -201,14 +201,6 @@ Get started in seconds with our verified environments. Click each icon below for |:-:|:-:|:-:|:-:| |Automatically compile and quantize YOLOv5 for better inference performance in one click at [Deci](https://bit.ly/yolov5-deci-platform)|Automatically track, visualize and even remotely train YOLOv5 using [ClearML](https://cutt.ly/yolov5-readme-clearml) (open-source!)|Label and export your custom datasets directly to YOLOv5 for training with [Roboflow](https://roboflow.com/?ref=ultralytics) |Automatically track and visualize all your YOLOv5 training runs in the cloud with [Weights & Biases](https://wandb.ai/site?utm_campaign=repo_yolo_readme) - ##
Why YOLOv5
@@ -254,6 +246,83 @@ We are super excited about our first-ever Ultralytics YOLOv5 🚀 EXPORT Competi +##
Classification ⭐ NEW
+ +YOLOv5 [release v6.2](https://github.com/ultralytics/yolov5/releases) brings support for classification model training, validation, prediction and export! We've made training classifier models super simple. Click below to get started. + +
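Training is also callable from Python through the `run()` helper that `classify/train.py` defines; a minimal sketch (run from the repo root so `classify` is importable):

```python
from classify import train

# programmatic equivalent of:
#   python classify/train.py --model yolov5s-cls.pt --data cifar100 --epochs 5 --img 224 --batch 128
opt = train.run(model='yolov5s-cls.pt', data='cifar100', epochs=5, imgsz=224, batch_size=128)
print(opt.save_dir)  # where last.pt / best.pt were written
```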
+ Classification Checkpoints (click to expand) + +
+
+We trained YOLOv5-cls classification models on ImageNet for 90 epochs using a 4xA100 instance, and we trained ResNet and EfficientNet models alongside under the same default training settings for comparison. We exported all models to ONNX FP32 for CPU speed tests and to TensorRT FP16 for GPU speed tests. We ran all speed tests on Google [Colab Pro](https://colab.research.google.com/signup) for easy reproducibility.
+
+| Model | size<br>(pixels) | acc<br>top1 | acc<br>top5 | Training<br>90 epochs<br>4xA100 (hours) | Speed<br>ONNX CPU<br>(ms) | Speed<br>TensorRT V100<br>(ms) | params<br>(M) | FLOPs<br>@224 (B) |
+|----------------------------------------------------------------------------------------------------|-----------------------|------------------|------------------|----------------------------------------------|--------------------------------|-------------------------------------|--------------------|------------------------|
+| [YOLOv5n-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5n-cls.pt) | 224 | 64.6 | 85.4 | 7:59 | **3.3** | **0.5** | **2.5** | **0.5** |
+| [YOLOv5s-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5s-cls.pt) | 224 | 71.5 | 90.2 | 8:09 | 6.6 | 0.6 | 5.4 | 1.4 |
+| [YOLOv5m-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5m-cls.pt) | 224 | 75.9 | 92.9 | 10:06 | 15.5 | 0.9 | 12.9 | 3.9 |
+| [YOLOv5l-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5l-cls.pt) | 224 | 78.0 | 94.0 | 11:56 | 26.9 | 1.4 | 26.5 | 8.5 |
+| [YOLOv5x-cls](https://github.com/ultralytics/yolov5/releases/download/v6.2/yolov5x-cls.pt) | 224 | **79.0** | **94.4** | 15:04 | 54.3 | 1.8 | 48.1 | 15.9 |
+| |
+| [ResNet18](https://github.com/ultralytics/yolov5/releases/download/v6.2/resnet18.pt) | 224 | 70.3 | 89.5 | **6:47** | 11.2 | 0.5 | 11.7 | 3.7 |
+| [ResNet34](https://github.com/ultralytics/yolov5/releases/download/v6.2/resnet34.pt) | 224 | 73.9 | 91.8 | 8:33 | 20.6 | 0.9 | 21.8 | 7.4 |
+| [ResNet50](https://github.com/ultralytics/yolov5/releases/download/v6.2/resnet50.pt) | 224 | 76.8 | 93.4 | 11:10 | 23.4 | 1.0 | 25.6 | 8.5 |
+| [ResNet101](https://github.com/ultralytics/yolov5/releases/download/v6.2/resnet101.pt) | 224 | 78.5 | 94.3 | 17:10 | 42.1 | 1.9 | 44.5 | 15.9 |
+| |
+| [EfficientNet_b0](https://github.com/ultralytics/yolov5/releases/download/v6.2/efficientnet_b0.pt) | 224 | 75.1 | 92.4 | 13:03 | 12.5 | 1.3 | 5.3 | 1.0 |
+| [EfficientNet_b1](https://github.com/ultralytics/yolov5/releases/download/v6.2/efficientnet_b1.pt) | 224 | 76.4 | 93.2 | 17:04 | 14.9 | 1.6 | 7.8 | 1.5 |
+| [EfficientNet_b2](https://github.com/ultralytics/yolov5/releases/download/v6.2/efficientnet_b2.pt) | 224 | 76.6 | 93.4 | 17:10 | 15.9 | 1.6 | 9.1 | 1.7 |
+| [EfficientNet_b3](https://github.com/ultralytics/yolov5/releases/download/v6.2/efficientnet_b3.pt) | 224 | 77.7 | 94.0 | 19:19 | 18.9 | 1.9 | 12.2 | 2.4 |
+
+ Table Notes (click to expand)
+
+- All checkpoints were trained for 90 epochs with the SGD optimizer (lr0=0.001) at image size 224 and all default settings. Runs logged to https://wandb.ai/glenn-jocher/YOLOv5-Classifier-v6-2.
+- **Accuracy** values are for single-model single-scale on the [ImageNet-1k](https://www.image-net.org/index.php) dataset.<br>Reproduce by `python classify/val.py --data ../datasets/imagenet --img 224`
+- **Speed** averaged over 100 inference images using a Google [Colab Pro](https://colab.research.google.com/signup) V100 High-RAM instance.<br>Reproduce by `python classify/val.py --data ../datasets/imagenet --img 224 --batch 1`
+- **Export** to ONNX at FP32 and TensorRT at FP16 done with `export.py`.<br>Reproduce by `python export.py --weights yolov5s-cls.pt --include engine onnx --imgsz 224`
+
+
+ +
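To consume the exported ONNX FP32 model on CPU outside this repo, something like the following works — a sketch assuming `yolov5s-cls.onnx` from the export command above and the `onnxruntime` package, with a random array standing in for a preprocessed image:

```python
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession('yolov5s-cls.onnx', providers=['CPUExecutionProvider'])
im = np.random.rand(1, 3, 224, 224).astype(np.float32)             # placeholder input
logits = session.run(None, {session.get_inputs()[0].name: im})[0]  # (1, 1000) class scores
print(logits.argmax(1)[0])                                         # predicted class index
```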
+ Classification Usage Examples (click to expand)
+
+### Train
+YOLOv5 classification training supports auto-download of MNIST, Fashion-MNIST, CIFAR10, CIFAR100, Imagenette, Imagewoof, and ImageNet datasets with the `--data` argument. For example, to start training on MNIST use `--data mnist`.
+
+```bash
+# Single-GPU
+python classify/train.py --model yolov5s-cls.pt --data cifar100 --epochs 5 --img 224 --batch 128
+
+# Multi-GPU DDP
+python -m torch.distributed.run --nproc_per_node 4 --master_port 1 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3
+```
+
+### Val
+Validate the accuracy of a pretrained model. To validate YOLOv5s-cls accuracy on ImageNet:
+```bash
+bash data/scripts/get_imagenet.sh --val  # download ImageNet val split (6.3G, 50000 images)
+python classify/val.py --weights yolov5s-cls.pt --data ../datasets/imagenet --img 224
+```
+
+### Predict
+Run a classification prediction on an image.
+```bash
+python classify/predict.py --weights yolov5s-cls.pt --source data/images/bus.jpg
+```
+```python
+model = torch.hub.load('ultralytics/yolov5', 'custom', 'yolov5s-cls.pt')  # load from PyTorch Hub
+```
+
+### Export
+Export a group of trained YOLOv5-cls, ResNet and EfficientNet models to ONNX and TensorRT.
+```bash
+python export.py --weights yolov5s-cls.pt resnet50.pt efficientnet_b0.pt --include onnx engine --img 224
+```
+
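Validation is likewise scriptable through the `run()` function in `classify/val.py`; a minimal sketch, assuming the ImageNet val split has been downloaded with `get_imagenet.sh --val` and you run from the repo root:

```python
from classify import val

# programmatic equivalent of:
#   python classify/val.py --weights yolov5s-cls.pt --data ../datasets/imagenet --img 224
top1, top5, loss = val.run(weights='yolov5s-cls.pt', data='../datasets/imagenet', imgsz=224)
print(f'top1 {top1:.3f}  top5 {top5:.3f}')
```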
+ + ##
Contribute
We love your input! We want to make contributing to YOLOv5 as easy and transparent as possible. Please see our [Contributing Guide](CONTRIBUTING.md) to get started, and fill out the [YOLOv5 Survey](https://ultralytics.com/survey?utm_source=github&utm_medium=social&utm_campaign=Survey) to send us feedback on your experiences. Thank you to all our contributors! diff --git a/classify/predict.py b/classify/predict.py new file mode 100644 index 000000000000..419830d43952 --- /dev/null +++ b/classify/predict.py @@ -0,0 +1,109 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Run classification inference on images + +Usage: + $ python classify/predict.py --weights yolov5s-cls.pt --source im.jpg +""" + +import argparse +import os +import sys +from pathlib import Path + +import cv2 +import torch.nn.functional as F + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +from classify.train import imshow_cls +from models.common import DetectMultiBackend +from utils.augmentations import classify_transforms +from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args +from utils.torch_utils import select_device, smart_inference_mode, time_sync + + +@smart_inference_mode() +def run( + weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s) + source=ROOT / 'data/images/bus.jpg', # file/dir/URL/glob, 0 for webcam + imgsz=224, # inference size + device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu + half=False, # use FP16 half-precision inference + dnn=False, # use OpenCV DNN for ONNX inference + show=True, + project=ROOT / 'runs/predict-cls', # save to project/name + name='exp', # save to project/name + exist_ok=False, # existing project/name ok, do not increment +): + file = str(source) + seen, dt = 1, [0.0, 0.0, 0.0] + device = select_device(device) + + # Directories + save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run + save_dir.mkdir(parents=True, exist_ok=True) # make dir + + # Transforms + transforms = classify_transforms(imgsz) + + # Load model + model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half) + model.warmup(imgsz=(1, 3, imgsz, imgsz)) # warmup + + # Image + t1 = time_sync() + im = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB) + im = transforms(im).unsqueeze(0).to(device) + im = im.half() if model.fp16 else im.float() + t2 = time_sync() + dt[0] += t2 - t1 + + # Inference + results = model(im) + t3 = time_sync() + dt[1] += t3 - t2 + + p = F.softmax(results, dim=1) # probabilities + i = p.argsort(1, descending=True)[:, :5].squeeze() # top 5 indices + dt[2] += time_sync() - t3 + LOGGER.info(f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}") + + # Print results + t = tuple(x / seen * 1E3 for x in dt) # speeds per image + shape = (1, 3, imgsz, imgsz) + LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t) + if show: + imshow_cls(im, f=save_dir / Path(file).name, verbose=True) + LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") + return p + + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)') + parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', help='file') + parser.add_argument('--imgsz', '--img', 
'--img-size', type=int, default=224, help='train, val image size (pixels)') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') + parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference') + parser.add_argument('--project', default=ROOT / 'runs/predict-cls', help='save to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + opt = parser.parse_args() + print_args(vars(opt)) + return opt + + +def main(opt): + check_requirements(exclude=('tensorboard', 'thop')) + run(**vars(opt)) + + +if __name__ == "__main__": + opt = parse_opt() + main(opt) diff --git a/classify/train.py b/classify/train.py new file mode 100644 index 000000000000..f2b465567446 --- /dev/null +++ b/classify/train.py @@ -0,0 +1,325 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Train a YOLOv5 classifier model on a classification dataset +Datasets: --data mnist, fashion-mnist, cifar10, cifar100, imagenette, imagewoof, imagenet, or 'path/to/custom/dataset' + +Usage: + $ python classify/train.py --model yolov5s-cls.pt --data cifar100 --epochs 5 --img 128 + $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3 +""" + +import argparse +import os +import subprocess +import sys +import time +from copy import deepcopy +from datetime import datetime +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.hub as hub +import torch.optim.lr_scheduler as lr_scheduler +import torchvision +from torch.cuda import amp +from tqdm import tqdm + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +from classify import val as validate +from models.experimental import attempt_load +from models.yolo import ClassificationModel, DetectionModel +from utils.dataloaders import create_classification_dataloader +from utils.general import (DATASETS_DIR, LOGGER, WorkingDirectory, check_git_status, check_requirements, colorstr, + download, increment_path, init_seeds, print_args, yaml_save) +from utils.loggers import GenericLogger +from utils.plots import imshow_cls +from utils.torch_utils import (ModelEMA, model_info, reshape_classifier_output, select_device, smart_DDP, + smart_optimizer, smartCrossEntropyLoss, torch_distributed_zero_first) + +LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html +RANK = int(os.getenv('RANK', -1)) +WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) + + +def train(opt, device): + init_seeds(opt.seed + 1 + RANK, deterministic=True) + save_dir, data, bs, epochs, nw, imgsz, pretrained = \ + opt.save_dir, Path(opt.data), opt.batch_size, opt.epochs, min(os.cpu_count() - 1, opt.workers), \ + opt.imgsz, str(opt.pretrained).lower() == 'true' + cuda = device.type != 'cpu' + + # Directories + wdir = save_dir / 'weights' + wdir.mkdir(parents=True, exist_ok=True) # make dir + last, best = wdir / 'last.pt', wdir / 'best.pt' + + # Save run settings + yaml_save(save_dir / 'opt.yaml', vars(opt)) + + # Logger + logger = GenericLogger(opt=opt, console_logger=LOGGER) if RANK in {-1, 0} 
else None + + # Download Dataset + with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT): + data_dir = data if data.is_dir() else (DATASETS_DIR / data) + if not data_dir.is_dir(): + LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') + t = time.time() + if str(data) == 'imagenet': + subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) + else: + url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{data}.zip' + download(url, dir=data_dir.parent) + s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" + LOGGER.info(s) + + # Dataloaders + nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes + trainloader = create_classification_dataloader(path=data_dir / 'train', + imgsz=imgsz, + batch_size=bs // WORLD_SIZE, + augment=True, + cache=opt.cache, + rank=LOCAL_RANK, + workers=nw) + + test_dir = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val + if RANK in {-1, 0}: + testloader = create_classification_dataloader(path=test_dir, + imgsz=imgsz, + batch_size=bs // WORLD_SIZE * 2, + augment=False, + cache=opt.cache, + rank=-1, + workers=nw) + + # Model + with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT): + if Path(opt.model).is_file() or opt.model.endswith('.pt'): + model = attempt_load(opt.model, device='cpu', fuse=False) + elif opt.model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0 + model = torchvision.models.__dict__[opt.model](weights='IMAGENET1K_V1' if pretrained else None) + else: + m = hub.list('ultralytics/yolov5') # + hub.list('pytorch/vision') # models + raise ModuleNotFoundError(f'--model {opt.model} not found. Available models are: \n' + '\n'.join(m)) + if isinstance(model, DetectionModel): + LOGGER.warning("WARNING: pass YOLOv5 classifier model with '-cls' suffix, i.e. 
'--model yolov5s-cls.pt'") + model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10) # convert to classification model + reshape_classifier_output(model, nc) # update class count + for p in model.parameters(): + p.requires_grad = True # for training + for m in model.modules(): + if not pretrained and hasattr(m, 'reset_parameters'): + m.reset_parameters() + if isinstance(m, torch.nn.Dropout) and opt.dropout is not None: + m.p = opt.dropout # set dropout + model = model.to(device) + names = trainloader.dataset.classes # class names + model.names = names # attach class names + + # Info + if RANK in {-1, 0}: + model_info(model) + if opt.verbose: + LOGGER.info(model) + images, labels = next(iter(trainloader)) + file = imshow_cls(images[:25], labels[:25], names=names, f=save_dir / 'train_images.jpg') + logger.log_images(file, name='Train Examples') + logger.log_graph(model, imgsz) # log model + + # Optimizer + optimizer = smart_optimizer(model, opt.optimizer, opt.lr0, momentum=0.9, decay=5e-5) + + # Scheduler + lrf = 0.01 # final lr (fraction of lr0) + # lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf # cosine + lf = lambda x: (1 - x / epochs) * (1 - lrf) + lrf # linear + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) + # scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr0, total_steps=epochs, pct_start=0.1, + # final_div_factor=1 / 25 / lrf) + + # EMA + ema = ModelEMA(model) if RANK in {-1, 0} else None + + # DDP mode + if cuda and RANK != -1: + model = smart_DDP(model) + + # Train + t0 = time.time() + criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing) # loss function + best_fitness = 0.0 + scaler = amp.GradScaler(enabled=cuda) + val = test_dir.stem # 'val' or 'test' + LOGGER.info(f'Image sizes {imgsz} train, {imgsz} test\n' + f'Using {nw * WORLD_SIZE} dataloader workers\n' + f"Logging results to {colorstr('bold', save_dir)}\n" + f'Starting {opt.model} training on {data} dataset with {nc} classes for {epochs} epochs...\n\n' + f"{'Epoch':>10}{'GPU_mem':>10}{'train_loss':>12}{f'{val}_loss':>12}{'top1_acc':>12}{'top5_acc':>12}") + for epoch in range(epochs): # loop over the dataset multiple times + tloss, vloss, fitness = 0.0, 0.0, 0.0 # train loss, val loss, fitness + model.train() + if RANK != -1: + trainloader.sampler.set_epoch(epoch) + pbar = enumerate(trainloader) + if RANK in {-1, 0}: + pbar = tqdm(enumerate(trainloader), total=len(trainloader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') + for i, (images, labels) in pbar: # progress bar + images, labels = images.to(device, non_blocking=True), labels.to(device) + + # Forward + with amp.autocast(enabled=cuda): # stability issues when enabled + loss = criterion(model(images), labels) + + # Backward + scaler.scale(loss).backward() + + # Optimize + scaler.unscale_(optimizer) # unscale gradients + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + if ema: + ema.update(model) + + if RANK in {-1, 0}: + # Print + tloss = (tloss * i + loss.item()) / (i + 1) # update mean losses + mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) + pbar.desc = f"{f'{epoch + 1}/{epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36 + + # Test + if i == len(pbar) - 1: # last batch + top1, top5, vloss = validate.run(model=ema.ema, + dataloader=testloader, + criterion=criterion, + pbar=pbar) # test accuracy, loss + fitness = top1 # define fitness 
as top1 accuracy + + # Scheduler + scheduler.step() + + # Log metrics + if RANK in {-1, 0}: + # Best fitness + if fitness > best_fitness: + best_fitness = fitness + + # Log + metrics = { + "train/loss": tloss, + f"{val}/loss": vloss, + "metrics/accuracy_top1": top1, + "metrics/accuracy_top5": top5, + "lr/0": optimizer.param_groups[0]['lr']} # learning rate + logger.log_metrics(metrics, epoch) + + # Save model + final_epoch = epoch + 1 == epochs + if (not opt.nosave) or final_epoch: + ckpt = { + 'epoch': epoch, + 'best_fitness': best_fitness, + 'model': deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(), + 'ema': None, # deepcopy(ema.ema).half(), + 'updates': ema.updates, + 'optimizer': None, # optimizer.state_dict(), + 'opt': vars(opt), + 'date': datetime.now().isoformat()} + + # Save last, best and delete + torch.save(ckpt, last) + if best_fitness == fitness: + torch.save(ckpt, best) + del ckpt + + # Train complete + if RANK in {-1, 0} and final_epoch: + LOGGER.info(f'\nTraining complete ({(time.time() - t0) / 3600:.3f} hours)' + f"\nResults saved to {colorstr('bold', save_dir)}" + f"\nPredict: python classify/predict.py --weights {best} --source im.jpg" + f"\nValidate: python classify/val.py --weights {best} --data {data_dir}" + f"\nExport: python export.py --weights {best} --include onnx" + f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{best}')" + f"\nVisualize: https://netron.app\n") + + # Plot examples + images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels + pred = torch.max(ema.ema((images.half() if cuda else images.float()).to(device)), 1)[1] + file = imshow_cls(images, labels, pred, names, verbose=False, f=save_dir / 'test_images.jpg') + + # Log results + meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()} + logger.log_images(file, name='Test Examples (true-predicted)', epoch=epoch) + logger.log_model(best, epochs, metadata=meta) + + +def parse_opt(known=False): + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, default='yolov5s-cls.pt', help='initial weights path') + parser.add_argument('--data', type=str, default='mnist', help='cifar10, cifar100, mnist, imagenet, etc.') + parser.add_argument('--epochs', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=64, help='total batch size for all GPUs') + parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=128, help='train, val image size (pixels)') + parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') + parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)') + parser.add_argument('--project', default=ROOT / 'runs/train-cls', help='save to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--pretrained', nargs='?', const=True, default=True, help='start from i.e. 
--pretrained False') + parser.add_argument('--optimizer', choices=['SGD', 'Adam', 'AdamW', 'RMSProp'], default='Adam', help='optimizer') + parser.add_argument('--lr0', type=float, default=0.001, help='initial learning rate') + parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing epsilon') + parser.add_argument('--cutoff', type=int, default=None, help='Model layer cutoff index for Classify() head') + parser.add_argument('--dropout', type=float, default=None, help='Dropout (fraction)') + parser.add_argument('--verbose', action='store_true', help='Verbose mode') + parser.add_argument('--seed', type=int, default=0, help='Global training seed') + parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify') + return parser.parse_known_args()[0] if known else parser.parse_args() + + +def main(opt): + # Checks + if RANK in {-1, 0}: + print_args(vars(opt)) + check_git_status() + check_requirements() + + # DDP mode + device = select_device(opt.device, batch_size=opt.batch_size) + if LOCAL_RANK != -1: + assert opt.batch_size != -1, 'AutoBatch is coming soon for classification, please pass a valid --batch-size' + assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE' + assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command' + torch.cuda.set_device(LOCAL_RANK) + device = torch.device('cuda', LOCAL_RANK) + dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo") + + # Parameters + opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run + + # Train + train(opt, device) + + +def run(**kwargs): + # Usage: from yolov5 import classify; classify.train.run(data=mnist, imgsz=320, model='yolov5m') + opt = parse_opt(True) + for k, v in kwargs.items(): + setattr(opt, k, v) + main(opt) + return opt + + +if __name__ == "__main__": + opt = parse_opt() + main(opt) diff --git a/classify/val.py b/classify/val.py new file mode 100644 index 000000000000..0930ba8c9c51 --- /dev/null +++ b/classify/val.py @@ -0,0 +1,158 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Validate a classification model on a dataset + +Usage: + $ python classify/val.py --weights yolov5s-cls.pt --data ../datasets/imagenet +""" + +import argparse +import os +import sys +from pathlib import Path + +import torch +from tqdm import tqdm + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +from models.common import DetectMultiBackend +from utils.dataloaders import create_classification_dataloader +from utils.general import LOGGER, check_img_size, check_requirements, colorstr, increment_path, print_args +from utils.torch_utils import select_device, smart_inference_mode, time_sync + + +@smart_inference_mode() +def run( + data=ROOT / '../datasets/mnist', # dataset dir + weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s) + batch_size=128, # batch size + imgsz=224, # inference size (pixels) + device='', # cuda device, i.e. 
0 or 0,1,2,3 or cpu + workers=8, # max dataloader workers (per RANK in DDP mode) + verbose=False, # verbose output + project=ROOT / 'runs/val-cls', # save to project/name + name='exp', # save to project/name + exist_ok=False, # existing project/name ok, do not increment + half=True, # use FP16 half-precision inference + dnn=False, # use OpenCV DNN for ONNX inference + model=None, + dataloader=None, + criterion=None, + pbar=None, +): + # Initialize/load model and set device + training = model is not None + if training: # called by train.py + device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model + half &= device.type != 'cpu' # half precision only supported on CUDA + model.half() if half else model.float() + else: # called directly + device = select_device(device, batch_size=batch_size) + + # Directories + save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run + save_dir.mkdir(parents=True, exist_ok=True) # make dir + + # Load model + model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half) + stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine + imgsz = check_img_size(imgsz, s=stride) # check image size + half = model.fp16 # FP16 supported on limited backends with CUDA + if engine: + batch_size = model.batch_size + else: + device = model.device + if not (pt or jit): + batch_size = 1 # export.py models default to batch-size 1 + LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') + + # Dataloader + data = Path(data) + test_dir = data / 'test' if (data / 'test').exists() else data / 'val' # data/test or data/val + dataloader = create_classification_dataloader(path=test_dir, + imgsz=imgsz, + batch_size=batch_size, + augment=False, + rank=-1, + workers=workers) + + model.eval() + pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0] + n = len(dataloader) # number of batches + action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing' + desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}" + bar = tqdm(dataloader, desc, n, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}', position=0) + with torch.cuda.amp.autocast(enabled=device.type != 'cpu'): + for images, labels in bar: + t1 = time_sync() + images, labels = images.to(device, non_blocking=True), labels.to(device) + t2 = time_sync() + dt[0] += t2 - t1 + + y = model(images) + t3 = time_sync() + dt[1] += t3 - t2 + + pred.append(y.argsort(1, descending=True)[:, :5]) + targets.append(labels) + if criterion: + loss += criterion(y, labels) + dt[2] += time_sync() - t3 + + loss /= n + pred, targets = torch.cat(pred), torch.cat(targets) + correct = (targets[:, None] == pred).float() + acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy + top1, top5 = acc.mean(0).tolist() + + if pbar: + pbar.desc = f"{pbar.desc[:-36]}{loss:>12.3g}{top1:>12.3g}{top5:>12.3g}" + if verbose: # all classes + LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}") + LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}") + for i, c in enumerate(model.names): + aci = acc[targets == i] + top1i, top5i = aci.mean(0).tolist() + LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}") + + # Print results + t = tuple(x / len(dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image + shape = (1, 3, imgsz, imgsz) + LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms 
post-process per image at shape {shape}' % t) + LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") + + return top1, top5, loss + + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, default=ROOT / '../datasets/mnist', help='dataset path') + parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model.pt path(s)') + parser.add_argument('--batch-size', type=int, default=128, help='batch size') + parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='inference size (pixels)') + parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') + parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)') + parser.add_argument('--verbose', nargs='?', const=True, default=True, help='verbose output') + parser.add_argument('--project', default=ROOT / 'runs/val-cls', help='save to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') + parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference') + opt = parser.parse_args() + print_args(vars(opt)) + return opt + + +def main(opt): + check_requirements(exclude=('tensorboard', 'thop')) + run(**vars(opt)) + + +if __name__ == "__main__": + opt = parse_opt() + main(opt) diff --git a/data/ImageNet.yaml b/data/ImageNet.yaml new file mode 100644 index 000000000000..9f89b4268aff --- /dev/null +++ b/data/ImageNet.yaml @@ -0,0 +1,156 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +# ImageNet-1k dataset https://www.image-net.org/index.php by Stanford University +# Simplified class names from https://github.com/anishathalye/imagenet-simple-labels +# Example usage: python classify/train.py --data imagenet +# parent +# ├── yolov5 +# └── datasets +# └── imagenet ← downloads here (144 GB) + + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/imagenet # dataset root dir +train: train # train images (relative to 'path') 1281167 images +val: val # val images (relative to 'path') 50000 images +test: # test images (optional) + +# Classes +nc: 1000 # number of classes +names: ['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead shark', 'electric ray', 'stingray', 'cock', + 'hen', 'ostrich', 'brambling', 'goldfinch', 'house finch', 'junco', 'indigo bunting', 'American robin', + 'bulbul', 'jay', 'magpie', 'chickadee', 'American dipper', 'kite', 'bald eagle', 'vulture', 'great grey owl', + 'fire salamander', 'smooth newt', 'newt', 'spotted salamander', 'axolotl', 'American bullfrog', 'tree frog', + 'tailed frog', 'loggerhead sea turtle', 'leatherback sea turtle', 'mud turtle', 'terrapin', 'box turtle', + 'banded gecko', 'green iguana', 'Carolina anole', 'desert grassland whiptail lizard', 'agama', + 'frilled-necked lizard', 'alligator lizard', 'Gila monster', 'European green lizard', 'chameleon', + 'Komodo dragon', 'Nile crocodile', 'American alligator', 'triceratops', 'worm snake', 'ring-necked snake', + 'eastern hog-nosed snake', 'smooth green snake', 'kingsnake', 'garter snake', 'water snake', 'vine snake', + 'night snake', 'boa constrictor', 'African rock python', 'Indian cobra', 'green mamba', 'sea snake', + 'Saharan horned viper', 'eastern diamondback rattlesnake', 'sidewinder', 'trilobite', 'harvestman', 'scorpion', + 'yellow garden spider', 'barn spider', 'European garden spider', 'southern black widow', 'tarantula', + 'wolf spider', 'tick', 'centipede', 'black grouse', 'ptarmigan', 'ruffed grouse', 'prairie grouse', 'peacock', + 'quail', 'partridge', 'grey parrot', 'macaw', 'sulphur-crested cockatoo', 'lorikeet', 'coucal', 'bee eater', + 'hornbill', 'hummingbird', 'jacamar', 'toucan', 'duck', 'red-breasted merganser', 'goose', 'black swan', + 'tusker', 'echidna', 'platypus', 'wallaby', 'koala', 'wombat', 'jellyfish', 'sea anemone', 'brain coral', + 'flatworm', 'nematode', 'conch', 'snail', 'slug', 'sea slug', 'chiton', 'chambered nautilus', 'Dungeness crab', + 'rock crab', 'fiddler crab', 'red king crab', 'American lobster', 'spiny lobster', 'crayfish', 'hermit crab', + 'isopod', 'white stork', 'black stork', 'spoonbill', 'flamingo', 'little blue heron', 'great egret', 'bittern', + 'crane (bird)', 'limpkin', 'common gallinule', 'American coot', 'bustard', 'ruddy turnstone', 'dunlin', + 'common redshank', 'dowitcher', 'oystercatcher', 'pelican', 'king penguin', 'albatross', 'grey whale', + 'killer whale', 'dugong', 'sea lion', 'Chihuahua', 'Japanese Chin', 'Maltese', 'Pekingese', 'Shih Tzu', + 'King Charles Spaniel', 'Papillon', 'toy terrier', 'Rhodesian Ridgeback', 'Afghan Hound', 'Basset Hound', + 'Beagle', 'Bloodhound', 'Bluetick Coonhound', 'Black and Tan Coonhound', 'Treeing Walker Coonhound', + 'English foxhound', 'Redbone Coonhound', 'borzoi', 'Irish Wolfhound', 'Italian Greyhound', 'Whippet', + 'Ibizan Hound', 'Norwegian Elkhound', 'Otterhound', 'Saluki', 'Scottish Deerhound', 'Weimaraner', + 'Staffordshire Bull Terrier', 'American Staffordshire Terrier', 'Bedlington Terrier', 'Border Terrier', + 'Kerry Blue Terrier', 'Irish Terrier', 'Norfolk Terrier', 'Norwich Terrier', 'Yorkshire Terrier', + 'Wire Fox Terrier', 'Lakeland Terrier', 'Sealyham Terrier', 'Airedale Terrier', 'Cairn Terrier', + 'Australian Terrier', 'Dandie Dinmont Terrier', 'Boston Terrier', 'Miniature Schnauzer', 'Giant Schnauzer', + 'Standard Schnauzer', 'Scottish Terrier', 'Tibetan Terrier', 'Australian Silky 
Terrier', + 'Soft-coated Wheaten Terrier', 'West Highland White Terrier', 'Lhasa Apso', 'Flat-Coated Retriever', + 'Curly-coated Retriever', 'Golden Retriever', 'Labrador Retriever', 'Chesapeake Bay Retriever', + 'German Shorthaired Pointer', 'Vizsla', 'English Setter', 'Irish Setter', 'Gordon Setter', 'Brittany', + 'Clumber Spaniel', 'English Springer Spaniel', 'Welsh Springer Spaniel', 'Cocker Spaniels', 'Sussex Spaniel', + 'Irish Water Spaniel', 'Kuvasz', 'Schipperke', 'Groenendael', 'Malinois', 'Briard', 'Australian Kelpie', + 'Komondor', 'Old English Sheepdog', 'Shetland Sheepdog', 'collie', 'Border Collie', 'Bouvier des Flandres', + 'Rottweiler', 'German Shepherd Dog', 'Dobermann', 'Miniature Pinscher', 'Greater Swiss Mountain Dog', + 'Bernese Mountain Dog', 'Appenzeller Sennenhund', 'Entlebucher Sennenhund', 'Boxer', 'Bullmastiff', + 'Tibetan Mastiff', 'French Bulldog', 'Great Dane', 'St. Bernard', 'husky', 'Alaskan Malamute', 'Siberian Husky', + 'Dalmatian', 'Affenpinscher', 'Basenji', 'pug', 'Leonberger', 'Newfoundland', 'Pyrenean Mountain Dog', + 'Samoyed', 'Pomeranian', 'Chow Chow', 'Keeshond', 'Griffon Bruxellois', 'Pembroke Welsh Corgi', + 'Cardigan Welsh Corgi', 'Toy Poodle', 'Miniature Poodle', 'Standard Poodle', 'Mexican hairless dog', + 'grey wolf', 'Alaskan tundra wolf', 'red wolf', 'coyote', 'dingo', 'dhole', 'African wild dog', 'hyena', + 'red fox', 'kit fox', 'Arctic fox', 'grey fox', 'tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat', + 'Egyptian Mau', 'cougar', 'lynx', 'leopard', 'snow leopard', 'jaguar', 'lion', 'tiger', 'cheetah', 'brown bear', + 'American black bear', 'polar bear', 'sloth bear', 'mongoose', 'meerkat', 'tiger beetle', 'ladybug', + 'ground beetle', 'longhorn beetle', 'leaf beetle', 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', + 'ant', 'grasshopper', 'cricket', 'stick insect', 'cockroach', 'mantis', 'cicada', 'leafhopper', 'lacewing', + 'dragonfly', 'damselfly', 'red admiral', 'ringlet', 'monarch butterfly', 'small white', 'sulphur butterfly', + 'gossamer-winged butterfly', 'starfish', 'sea urchin', 'sea cucumber', 'cottontail rabbit', 'hare', + 'Angora rabbit', 'hamster', 'porcupine', 'fox squirrel', 'marmot', 'beaver', 'guinea pig', 'common sorrel', + 'zebra', 'pig', 'wild boar', 'warthog', 'hippopotamus', 'ox', 'water buffalo', 'bison', 'ram', 'bighorn sheep', + 'Alpine ibex', 'hartebeest', 'impala', 'gazelle', 'dromedary', 'llama', 'weasel', 'mink', 'European polecat', + 'black-footed ferret', 'otter', 'skunk', 'badger', 'armadillo', 'three-toed sloth', 'orangutan', 'gorilla', + 'chimpanzee', 'gibbon', 'siamang', 'guenon', 'patas monkey', 'baboon', 'macaque', 'langur', + 'black-and-white colobus', 'proboscis monkey', 'marmoset', 'white-headed capuchin', 'howler monkey', 'titi', + "Geoffroy's spider monkey", 'common squirrel monkey', 'ring-tailed lemur', 'indri', 'Asian elephant', + 'African bush elephant', 'red panda', 'giant panda', 'snoek', 'eel', 'coho salmon', 'rock beauty', 'clownfish', + 'sturgeon', 'garfish', 'lionfish', 'pufferfish', 'abacus', 'abaya', 'academic gown', 'accordion', + 'acoustic guitar', 'aircraft carrier', 'airliner', 'airship', 'altar', 'ambulance', 'amphibious vehicle', + 'analog clock', 'apiary', 'apron', 'waste container', 'assault rifle', 'backpack', 'bakery', 'balance beam', + 'balloon', 'ballpoint pen', 'Band-Aid', 'banjo', 'baluster', 'barbell', 'barber chair', 'barbershop', 'barn', + 'barometer', 'barrel', 'wheelbarrow', 'baseball', 'basketball', 'bassinet', 'bassoon', 'swimming cap', + 'bath 
towel', 'bathtub', 'station wagon', 'lighthouse', 'beaker', 'military cap', 'beer bottle', 'beer glass', + 'bell-cot', 'bib', 'tandem bicycle', 'bikini', 'ring binder', 'binoculars', 'birdhouse', 'boathouse', + 'bobsleigh', 'bolo tie', 'poke bonnet', 'bookcase', 'bookstore', 'bottle cap', 'bow', 'bow tie', 'brass', 'bra', + 'breakwater', 'breastplate', 'broom', 'bucket', 'buckle', 'bulletproof vest', 'high-speed train', + 'butcher shop', 'taxicab', 'cauldron', 'candle', 'cannon', 'canoe', 'can opener', 'cardigan', 'car mirror', + 'carousel', 'tool kit', 'carton', 'car wheel', 'automated teller machine', 'cassette', 'cassette player', + 'castle', 'catamaran', 'CD player', 'cello', 'mobile phone', 'chain', 'chain-link fence', 'chain mail', + 'chainsaw', 'chest', 'chiffonier', 'chime', 'china cabinet', 'Christmas stocking', 'church', 'movie theater', + 'cleaver', 'cliff dwelling', 'cloak', 'clogs', 'cocktail shaker', 'coffee mug', 'coffeemaker', 'coil', + 'combination lock', 'computer keyboard', 'confectionery store', 'container ship', 'convertible', 'corkscrew', + 'cornet', 'cowboy boot', 'cowboy hat', 'cradle', 'crane (machine)', 'crash helmet', 'crate', 'infant bed', + 'Crock Pot', 'croquet ball', 'crutch', 'cuirass', 'dam', 'desk', 'desktop computer', 'rotary dial telephone', + 'diaper', 'digital clock', 'digital watch', 'dining table', 'dishcloth', 'dishwasher', 'disc brake', 'dock', + 'dog sled', 'dome', 'doormat', 'drilling rig', 'drum', 'drumstick', 'dumbbell', 'Dutch oven', 'electric fan', + 'electric guitar', 'electric locomotive', 'entertainment center', 'envelope', 'espresso machine', 'face powder', + 'feather boa', 'filing cabinet', 'fireboat', 'fire engine', 'fire screen sheet', 'flagpole', 'flute', + 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen', 'four-poster bed', 'freight car', + 'French horn', 'frying pan', 'fur coat', 'garbage truck', 'gas mask', 'gas pump', 'goblet', 'go-kart', + 'golf ball', 'golf cart', 'gondola', 'gong', 'gown', 'grand piano', 'greenhouse', 'grille', 'grocery store', + 'guillotine', 'barrette', 'hair spray', 'half-track', 'hammer', 'hamper', 'hair dryer', 'hand-held computer', + 'handkerchief', 'hard disk drive', 'harmonica', 'harp', 'harvester', 'hatchet', 'holster', 'home theater', + 'honeycomb', 'hook', 'hoop skirt', 'horizontal bar', 'horse-drawn vehicle', 'hourglass', 'iPod', 'clothes iron', + "jack-o'-lantern", 'jeans', 'jeep', 'T-shirt', 'jigsaw puzzle', 'pulled rickshaw', 'joystick', 'kimono', + 'knee pad', 'knot', 'lab coat', 'ladle', 'lampshade', 'laptop computer', 'lawn mower', 'lens cap', + 'paper knife', 'library', 'lifeboat', 'lighter', 'limousine', 'ocean liner', 'lipstick', 'slip-on shoe', + 'lotion', 'speaker', 'loupe', 'sawmill', 'magnetic compass', 'mail bag', 'mailbox', 'tights', 'tank suit', + 'manhole cover', 'maraca', 'marimba', 'mask', 'match', 'maypole', 'maze', 'measuring cup', 'medicine chest', + 'megalith', 'microphone', 'microwave oven', 'military uniform', 'milk can', 'minibus', 'miniskirt', 'minivan', + 'missile', 'mitten', 'mixing bowl', 'mobile home', 'Model T', 'modem', 'monastery', 'monitor', 'moped', + 'mortar', 'square academic cap', 'mosque', 'mosquito net', 'scooter', 'mountain bike', 'tent', 'computer mouse', + 'mousetrap', 'moving van', 'muzzle', 'nail', 'neck brace', 'necklace', 'nipple', 'notebook computer', 'obelisk', + 'oboe', 'ocarina', 'odometer', 'oil filter', 'organ', 'oscilloscope', 'overskirt', 'bullock cart', + 'oxygen mask', 'packet', 'paddle', 'paddle wheel', 'padlock', 
'paintbrush', 'pajamas', 'palace', 'pan flute', + 'paper towel', 'parachute', 'parallel bars', 'park bench', 'parking meter', 'passenger car', 'patio', + 'payphone', 'pedestal', 'pencil case', 'pencil sharpener', 'perfume', 'Petri dish', 'photocopier', 'plectrum', + 'Pickelhaube', 'picket fence', 'pickup truck', 'pier', 'piggy bank', 'pill bottle', 'pillow', 'ping-pong ball', + 'pinwheel', 'pirate ship', 'pitcher', 'hand plane', 'planetarium', 'plastic bag', 'plate rack', 'plow', + 'plunger', 'Polaroid camera', 'pole', 'police van', 'poncho', 'billiard table', 'soda bottle', 'pot', + "potter's wheel", 'power drill', 'prayer rug', 'printer', 'prison', 'projectile', 'projector', 'hockey puck', + 'punching bag', 'purse', 'quill', 'quilt', 'race car', 'racket', 'radiator', 'radio', 'radio telescope', + 'rain barrel', 'recreational vehicle', 'reel', 'reflex camera', 'refrigerator', 'remote control', 'restaurant', + 'revolver', 'rifle', 'rocking chair', 'rotisserie', 'eraser', 'rugby ball', 'ruler', 'running shoe', 'safe', + 'safety pin', 'salt shaker', 'sandal', 'sarong', 'saxophone', 'scabbard', 'weighing scale', 'school bus', + 'schooner', 'scoreboard', 'CRT screen', 'screw', 'screwdriver', 'seat belt', 'sewing machine', 'shield', + 'shoe store', 'shoji', 'shopping basket', 'shopping cart', 'shovel', 'shower cap', 'shower curtain', 'ski', + 'ski mask', 'sleeping bag', 'slide rule', 'sliding door', 'slot machine', 'snorkel', 'snowmobile', 'snowplow', + 'soap dispenser', 'soccer ball', 'sock', 'solar thermal collector', 'sombrero', 'soup bowl', 'space bar', + 'space heater', 'space shuttle', 'spatula', 'motorboat', 'spider web', 'spindle', 'sports car', 'spotlight', + 'stage', 'steam locomotive', 'through arch bridge', 'steel drum', 'stethoscope', 'scarf', 'stone wall', + 'stopwatch', 'stove', 'strainer', 'tram', 'stretcher', 'couch', 'stupa', 'submarine', 'suit', 'sundial', + 'sunglass', 'sunglasses', 'sunscreen', 'suspension bridge', 'mop', 'sweatshirt', 'swimsuit', 'swing', 'switch', + 'syringe', 'table lamp', 'tank', 'tape player', 'teapot', 'teddy bear', 'television', 'tennis ball', + 'thatched roof', 'front curtain', 'thimble', 'threshing machine', 'throne', 'tile roof', 'toaster', + 'tobacco shop', 'toilet seat', 'torch', 'totem pole', 'tow truck', 'toy store', 'tractor', 'semi-trailer truck', + 'tray', 'trench coat', 'tricycle', 'trimaran', 'tripod', 'triumphal arch', 'trolleybus', 'trombone', 'tub', + 'turnstile', 'typewriter keyboard', 'umbrella', 'unicycle', 'upright piano', 'vacuum cleaner', 'vase', 'vault', + 'velvet', 'vending machine', 'vestment', 'viaduct', 'violin', 'volleyball', 'waffle iron', 'wall clock', + 'wallet', 'wardrobe', 'military aircraft', 'sink', 'washing machine', 'water bottle', 'water jug', + 'water tower', 'whiskey jug', 'whistle', 'wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', + 'wing', 'wok', 'wooden spoon', 'wool', 'split-rail fence', 'shipwreck', 'yawl', 'yurt', 'website', 'comic book', + 'crossword', 'traffic sign', 'traffic light', 'dust jacket', 'menu', 'plate', 'guacamole', 'consomme', + 'hot pot', 'trifle', 'ice cream', 'ice pop', 'baguette', 'bagel', 'pretzel', 'cheeseburger', 'hot dog', + 'mashed potato', 'cabbage', 'broccoli', 'cauliflower', 'zucchini', 'spaghetti squash', 'acorn squash', + 'butternut squash', 'cucumber', 'artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith', 'strawberry', + 'orange', 'lemon', 'fig', 'pineapple', 'banana', 'jackfruit', 'custard apple', 'pomegranate', 'hay', + 'carbonara', 
'chocolate syrup', 'dough', 'meatloaf', 'pizza', 'pot pie', 'burrito', 'red wine', 'espresso', + 'cup', 'eggnog', 'alp', 'bubble', 'cliff', 'coral reef', 'geyser', 'lakeshore', 'promontory', 'shoal', + 'seashore', 'valley', 'volcano', 'baseball player', 'bridegroom', 'scuba diver', 'rapeseed', 'daisy', + "yellow lady's slipper", 'corn', 'acorn', 'rose hip', 'horse chestnut seed', 'coral fungus', 'agaric', + 'gyromitra', 'stinkhorn mushroom', 'earth star', 'hen-of-the-woods', 'bolete', 'ear', + 'toilet paper'] # class names + +# Download script/URL (optional) +download: data/scripts/get_imagenet.sh diff --git a/data/scripts/download_weights.sh b/data/scripts/download_weights.sh index e9fa65394178..a4f3becfdbeb 100755 --- a/data/scripts/download_weights.sh +++ b/data/scripts/download_weights.sh @@ -1,7 +1,7 @@ #!/bin/bash # YOLOv5 🚀 by Ultralytics, GPL-3.0 license # Download latest models from https://github.com/ultralytics/yolov5/releases -# Example usage: bash path/to/download_weights.sh +# Example usage: bash data/scripts/download_weights.sh # parent # └── yolov5 # ├── yolov5s.pt ← downloads here @@ -11,10 +11,11 @@ python - <=7.0.0 + if device.type == 'cpu': + device = torch.device('cuda:0') Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr')) logger = trt.Logger(trt.Logger.INFO) with open(w, 'rb') as f, trt.Runtime(logger) as runtime: @@ -398,8 +396,8 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, if dtype == np.float16: fp16 = True shape = tuple(context.get_binding_shape(index)) - data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device) - bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr())) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size elif coreml: # CoreML @@ -445,9 +443,16 @@ def wrap_frozen_graph(gd, inputs, outputs): input_details = interpreter.get_input_details() # inputs output_details = interpreter.get_output_details() # outputs elif tfjs: - raise Exception('ERROR: YOLOv5 TF.js inference is not supported') + raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported') else: - raise Exception(f'ERROR: {w} is not a supported format') + raise NotImplementedError(f'ERROR: {w} is not a supported format') + + # class names + if 'names' not in locals(): + names = yaml_load(data)['names'] if data else [f'class{i}' for i in range(999)] + if names[0] == 'n01440764' and len(names) == 1000: # ImageNet + names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names + self.__dict__.update(locals()) # assign all variables to self def forward(self, im, augment=False, visualize=False, val=False): @@ -457,7 +462,9 @@ def forward(self, im, augment=False, visualize=False, val=False): im = im.half() # to FP16 if self.pt: # PyTorch - y = self.model(im, augment=augment, visualize=visualize)[0] + y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im) + if isinstance(y, tuple): + y = y[0] elif self.jit: # TorchScript y = self.model(im)[0] elif self.dnn: # ONNX OpenCV DNN @@ -526,7 +533,7 @@ def warmup(self, imgsz=(1, 3, 640, 640)): self.forward(im) # warmup @staticmethod - def model_type(p='path/to/model.pt'): + def _model_type(p='path/to/model.pt'): # Return model type from model path, 
@@ -526,7 +533,7 @@ def warmup(self, imgsz=(1, 3, 640, 640)):
             self.forward(im)  # warmup

     @staticmethod
-    def model_type(p='path/to/model.pt'):
+    def _model_type(p='path/to/model.pt'):
         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
         from export import export_formats
         suffixes = list(export_formats().Suffix) + ['.xml']  # export suffixes
@@ -540,8 +547,7 @@ def model_type(p='path/to/model.pt'):
     @staticmethod
     def _load_metadata(f='path/to/meta.yaml'):
         # Load metadata from meta.yaml if it exists
-        with open(f, errors='ignore') as f:
-            d = yaml.safe_load(f)
+        d = yaml_load(f)
         return d['stride'], d['names']  # assign stride, names

@@ -753,10 +759,13 @@ class Classify(nn.Module):
     # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
         super().__init__()
-        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
-        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
-        self.flat = nn.Flatten()
+        c_ = 1280  # efficientnet_b0 size
+        self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
+        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
+        self.drop = nn.Dropout(p=0.0, inplace=True)
+        self.linear = nn.Linear(c_, c2)  # to x(b,c2)

     def forward(self, x):
-        z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
-        return self.flat(self.conv(z))  # flatten to x(b,c2)
+        if isinstance(x, list):
+            x = torch.cat(x, 1)
+        return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
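The rewritten Classify head replaces the old conv-pool-flatten stack with a conv to a fixed 1280-channel width (the efficientnet_b0 size), global average pooling, dropout and a final linear layer. A minimal standalone sketch of the shape flow, using a plain nn.Conv2d as a stand-in for the repo's Conv block (which also adds BatchNorm and SiLU):

```python
import torch
import torch.nn as nn

c1, c2, c_ = 512, 1000, 1280  # ch_in, n_classes, fixed efficientnet_b0 width
conv = nn.Conv2d(c1, c_, kernel_size=1)  # x(b,c1,20,20) -> x(b,1280,20,20)
pool = nn.AdaptiveAvgPool2d(1)           # -> x(b,1280,1,1)
drop = nn.Dropout(p=0.0, inplace=True)   # disabled by default, as in the diff
linear = nn.Linear(c_, c2)               # -> x(b,1000)

x = torch.randn(2, c1, 20, 20)           # dummy backbone feature map
y = linear(drop(pool(conv(x)).flatten(1)))
print(y.shape)  # torch.Size([2, 1000])
```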
diff --git a/models/experimental.py b/models/experimental.py
index 0317c7526c99..cb32d01ba46a 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -79,7 +79,9 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt = torch.load(attempt_download(w), map_location='cpu')  # load
         ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
-        model.append(ckpt.fuse().eval() if fuse else ckpt.eval())  # fused or un-fused model in eval mode
+        if not hasattr(ckpt, 'stride'):
+            ckpt.stride = torch.tensor([32.])  # compatibility update for ResNet etc.
+        model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode

     # Compatibility updates
     for m in model.modules():
@@ -92,11 +94,14 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
         elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
             m.recompute_scale_factor = None  # torch 1.11.0 compatibility

+    # Return model
     if len(model) == 1:
-        return model[-1]  # return model
+        return model[-1]
+
+    # Return detection ensemble
     print(f'Ensemble created with {weights}\n')
     for k in 'names', 'nc', 'yaml':
         setattr(model, k, getattr(model[0], k))
     model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
     assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
-    return model  # return ensemble
+    return model
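The new hasattr() guards are what let attempt_load() return torchvision-style classifiers, which have neither a stride attribute nor a fuse() method. A self-contained illustration of the fallback path, with resnet18 standing in for a loaded checkpoint:

```python
import torch
import torchvision

ckpt = torchvision.models.resnet18()  # stand-in for a checkpoint without YOLOv5 attributes
if not hasattr(ckpt, 'stride'):
    ckpt.stride = torch.tensor([32.])  # compatibility default, as in attempt_load()
fuse = True
model = ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()  # no fuse(): plain eval()
print(model.stride)  # tensor([32.])
```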
diff --git a/models/yolo.py b/models/yolo.py
index 307b74844ca0..df4209726e0d 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -90,8 +90,64 @@ def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version
         return grid, anchor_grid


-class Model(nn.Module):
-    # YOLOv5 model
+class BaseModel(nn.Module):
+    # YOLOv5 base model
+    def forward(self, x, profile=False, visualize=False):
+        return self._forward_once(x, profile, visualize)  # single-scale inference, train
+
+    def _forward_once(self, x, profile=False, visualize=False):
+        y, dt = [], []  # outputs
+        for m in self.model:
+            if m.f != -1:  # if not from previous layer
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
+            if profile:
+                self._profile_one_layer(m, x, dt)
+            x = m(x)  # run
+            y.append(x if m.i in self.save else None)  # save output
+            if visualize:
+                feature_visualization(x, m.type, m.i, save_dir=visualize)
+        return x
+
+    def _profile_one_layer(self, m, x, dt):
+        c = m == self.model[-1]  # is final layer, copy input as inplace fix
+        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
+        t = time_sync()
+        for _ in range(10):
+            m(x.copy() if c else x)
+        dt.append((time_sync() - t) * 100)
+        if m == self.model[0]:
+            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
+        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
+        if c:
+            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")
+
+    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
+        LOGGER.info('Fusing layers... ')
+        for m in self.model.modules():
+            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
+                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+                delattr(m, 'bn')  # remove batchnorm
+                m.forward = m.forward_fuse  # update forward
+        self.info()
+        return self
+
+    def info(self, verbose=False, img_size=640):  # print model information
+        model_info(self, verbose, img_size)
+
+    def _apply(self, fn):
+        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
+        self = super()._apply(fn)
+        m = self.model[-1]  # Detect()
+        if isinstance(m, Detect):
+            m.stride = fn(m.stride)
+            m.grid = list(map(fn, m.grid))
+            if isinstance(m.anchor_grid, list):
+                m.anchor_grid = list(map(fn, m.anchor_grid))
+        return self
+
+
+class DetectionModel(BaseModel):
+    # YOLOv5 detection model
     def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
         super().__init__()
         if isinstance(cfg, dict):
@@ -149,19 +205,6 @@ def _forward_augment(self, x):
         y = self._clip_augmented(y)  # clip augmented tails
         return torch.cat(y, 1), None  # augmented inference, train

-    def _forward_once(self, x, profile=False, visualize=False):
-        y, dt = [], []  # outputs
-        for m in self.model:
-            if m.f != -1:  # if not from previous layer
-                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
-            if profile:
-                self._profile_one_layer(m, x, dt)
-            x = m(x)  # run
-            y.append(x if m.i in self.save else None)  # save output
-            if visualize:
-                feature_visualization(x, m.type, m.i, save_dir=visualize)
-        return x
-
     def _descale_pred(self, p, flips, scale, img_size):
         # de-scale predictions following augmented inference (inverse operation)
         if self.inplace:
@@ -190,19 +233,6 @@ def _clip_augmented(self, y):
         y[-1] = y[-1][:, i:]  # small
         return y

-    def _profile_one_layer(self, m, x, dt):
-        c = isinstance(m, Detect)  # is final layer, copy input as inplace fix
-        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
-        t = time_sync()
-        for _ in range(10):
-            m(x.copy() if c else x)
-        dt.append((time_sync() - t) * 100)
-        if m == self.model[0]:
-            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
-        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
-        if c:
-            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")
-
     def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
         # https://arxiv.org/abs/1708.02002 section 3.3
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
@@ -213,41 +243,34 @@ def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is
         b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # cls
         mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

-    def _print_biases(self):
-        m = self.model[-1]  # Detect() module
-        for mi in m.m:  # from
-            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
-            LOGGER.info(
-                ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
-
-    # def _print_weights(self):
-    #     for m in self.model.modules():
-    #         if type(m) is Bottleneck:
-    #             LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

+Model = DetectionModel  # retain YOLOv5 'Model' class for backwards compatibility

-    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
-        LOGGER.info('Fusing layers... ')
-        for m in self.model.modules():
-            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
-                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
-                delattr(m, 'bn')  # remove batchnorm
-                m.forward = m.forward_fuse  # update forward
-        self.info()
-        return self
-
-    def info(self, verbose=False, img_size=640):  # print model information
-        model_info(self, verbose, img_size)
-
-    def _apply(self, fn):
-        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
-        self = super()._apply(fn)
-        m = self.model[-1]  # Detect()
-        if isinstance(m, Detect):
-            m.stride = fn(m.stride)
-            m.grid = list(map(fn, m.grid))
-            if isinstance(m.anchor_grid, list):
-                m.anchor_grid = list(map(fn, m.anchor_grid))
-        return self

+class ClassificationModel(BaseModel):
+    # YOLOv5 classification model
+    def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):  # yaml, model, number of classes, cutoff index
+        super().__init__()
+        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
+
+    def _from_detection_model(self, model, nc=1000, cutoff=10):
+        # Create a YOLOv5 classification model from a YOLOv5 detection model
+        if isinstance(model, DetectMultiBackend):
+            model = model.model  # unwrap DetectMultiBackend
+        model.model = model.model[:cutoff]  # backbone
+        m = model.model[-1]  # last layer
+        ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels  # ch into module
+        c = Classify(ch, nc)  # Classify()
+        c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type
+        model.model[-1] = c  # replace
+        self.model = model.model
+        self.stride = model.stride
+        self.save = []
+        self.nc = nc
+
+    def _from_yaml(self, cfg):
+        # Create a YOLOv5 classification model from a *.yaml file
+        self.model = None


 def parse_model(d, ch):  # model_dict, input_channels(3)
@@ -321,7 +344,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
     # Options
     if opt.line_profile:  # profile layer by layer
-        _ = model(im, profile=True)
+        model(im, profile=True)

     elif opt.profile:  # profile forward-backward
         results = profile(input=im, ops=[model], n=3)
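The intended use of ClassificationModel is to convert a detection checkpoint into a classifier by keeping the first `cutoff` backbone layers and swapping in a Classify head. A hypothetical sketch, assuming a YOLOv5 repo checkout and a local yolov5s.pt checkpoint:

```python
from models.experimental import attempt_load
from models.yolo import ClassificationModel

det = attempt_load('yolov5s.pt', fuse=False)            # DetectionModel (assumed local checkpoint)
cls = ClassificationModel(model=det, nc=10, cutoff=10)  # backbone[:10] + Classify(ch, 10)
print(type(cls.model[-1]).__name__)                     # Classify
```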
diff --git a/train.py b/train.py
index d24ac57df23d..bbb26cdeafeb 100644
--- a/train.py
+++ b/train.py
@@ -47,7 +47,7 @@
 from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
                            check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path,
                            init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
-                           one_cycle, print_args, print_mutation, strip_optimizer)
+                           one_cycle, print_args, print_mutation, strip_optimizer, yaml_save)
 from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
@@ -81,10 +81,8 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     # Save run settings
     if not evolve:
-        with open(save_dir / 'hyp.yaml', 'w') as f:
-            yaml.safe_dump(hyp, f, sort_keys=False)
-        with open(save_dir / 'opt.yaml', 'w') as f:
-            yaml.safe_dump(vars(opt), f, sort_keys=False)
+        yaml_save(save_dir / 'hyp.yaml', hyp)
+        yaml_save(save_dir / 'opt.yaml', vars(opt))

     # Loggers
     data_dict = None
@@ -484,7 +482,7 @@ def main(opt, callbacks=Callbacks()):
     if RANK in {-1, 0}:
         print_args(vars(opt))
         check_git_status()
-        check_requirements(exclude=['thop'])
+        check_requirements()

     # Resume
     if opt.resume and not (check_wandb_resume(opt) or opt.evolve):  # resume from specified or most recent last.pt
diff --git a/utils/augmentations.py b/utils/augmentations.py
index 3f764c06ae3b..a55fefa68a76 100644
--- a/utils/augmentations.py
+++ b/utils/augmentations.py
@@ -8,15 +8,21 @@
 import cv2
 import numpy as np
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF

 from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box
 from utils.metrics import bbox_ioa

+IMAGENET_MEAN = 0.485, 0.456, 0.406  # RGB mean
+IMAGENET_STD = 0.229, 0.224, 0.225  # RGB standard deviation
+

 class Albumentations:
     # YOLOv5 Albumentations class (optional, only used if package is installed)
     def __init__(self):
         self.transform = None
+        prefix = colorstr('albumentations: ')
         try:
             import albumentations as A
             check_version(A.__version__, '1.0.3', hard=True)  # version requirement
@@ -31,11 +37,11 @@ def __init__(self):
                 A.ImageCompression(quality_lower=75, p=0.0)]  # transforms
             self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

-            LOGGER.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p))
+            LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
         except ImportError:  # package not installed, skip
             pass
         except Exception as e:
-            LOGGER.info(colorstr('albumentations: ') + f'{e}')
+            LOGGER.info(f'{prefix}{e}')

     def __call__(self, im, labels, p=1.0):
         if self.transform and random.random() < p:
@@ -44,6 +50,18 @@ def __call__(self, im, labels, p=1.0):
         return im, labels


+def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
+    # Normalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
+    return TF.normalize(x, mean, std, inplace=inplace)
+
+
+def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
+    # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
+    for i in range(3):
+        x[:, i] = x[:, i] * std[i] + mean[i]
+    return x
+
+
 def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
     # HSV color-space augmentation
     if hgain or sgain or vgain:
@@ -282,3 +300,48 @@ def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
     w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
     ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
     return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
+
+
+def classify_albumentations(augment=True,
+                            size=224,
+                            scale=(0.08, 1.0),
+                            hflip=0.5,
+                            vflip=0.0,
+                            jitter=0.4,
+                            mean=IMAGENET_MEAN,
+                            std=IMAGENET_STD,
+                            auto_aug=False):
+    # YOLOv5 classification Albumentations (optional, only used if package is installed)
+    prefix = colorstr('albumentations: ')
+    try:
+        import albumentations as A
+        from albumentations.pytorch import ToTensorV2
+        check_version(A.__version__, '1.0.3', hard=True)  # version requirement
+        if augment:  # Resize and crop
+            T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
+            if auto_aug:
+                # TODO: implement AugMix, AutoAug & RandAug in albumentations
+                LOGGER.info(f'{prefix}auto augmentations are currently not supported')
+            else:
+                if hflip > 0:
+                    T += [A.HorizontalFlip(p=hflip)]
+                if vflip > 0:
+                    T += [A.VerticalFlip(p=vflip)]
+                if jitter > 0:
+                    color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
+                    T += [A.ColorJitter(*color_jitter, 0)]
+        else:  # Use fixed crop for eval set (reproducibility)
+            T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
+        T += [A.Normalize(mean=mean, std=std), ToTensorV2()]  # Normalize and convert to Tensor
+        LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
+        return A.Compose(T)
+
+    except ImportError:  # package not installed, skip
+        pass
+    except Exception as e:
+        LOGGER.info(f'{prefix}{e}')
+
+
+def classify_transforms(size=224):
+    # Transforms to apply if albumentations not installed
+    return T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
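When Albumentations is not installed, the torchvision fallback above is all that runs. A minimal sketch of applying it to a single image, assuming Pillow is available and 'im.jpg' is a hypothetical local file:

```python
from PIL import Image
from utils.augmentations import classify_transforms

t = classify_transforms(size=224)  # ToTensor -> Resize -> CenterCrop -> Normalize
x = t(Image.open('im.jpg'))        # 3x224x224 float tensor, ImageNet-normalized
print(x.shape)                     # torch.Size([3, 224, 224])
```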
diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index 00f6413df7ad..2c04040bf25d 100755
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -22,12 +22,14 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
+import torchvision
 import yaml
 from PIL import ExifTags, Image, ImageOps
 from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm

-from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
+from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
+                                 letterbox, mixup, random_perspective)
 from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
                            cv2, is_colab, is_kaggle, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
 from utils.torch_utils import torch_distributed_zero_first
@@ -870,7 +872,7 @@ def flatten_recursive(path=DATASETS_DIR / 'coco128'):
 def extract_boxes(path=DATASETS_DIR / 'coco128'):  # from utils.dataloaders import *; extract_boxes()
     # Convert detection dataset into classification dataset, with one directory per class
     path = Path(path)  # images dir
-    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
+    shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None  # remove existing
     files = list(path.rglob('*.*'))
     n = len(files)  # number of files
     for im_file in tqdm(files, total=n):
@@ -1090,3 +1092,65 @@ def process_images(self):
                 pass
     print(f'Done. All images saved to {self.im_dir}')
     return self.im_dir
+
+
+# Classification dataloaders -------------------------------------------------------------------------------------------
+class ClassificationDataset(torchvision.datasets.ImageFolder):
+    """
+    YOLOv5 Classification Dataset.
+    Arguments
+        root:  Dataset path
+        transform:  torchvision transforms, used by default
+        album_transform: Albumentations transforms, used if installed
+    """
+
+    def __init__(self, root, augment, imgsz, cache=False):
+        super().__init__(root=root)
+        self.torch_transforms = classify_transforms(imgsz)
+        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
+        self.cache_ram = cache is True or cache == 'ram'
+        self.cache_disk = cache == 'disk'
+        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
+
+    def __getitem__(self, i):
+        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
+        if self.album_transforms:
+            if self.cache_ram and im is None:
+                im = self.samples[i][3] = cv2.imread(f)
+            elif self.cache_disk:
+                if not fn.exists():  # load npy
+                    np.save(fn.as_posix(), cv2.imread(f))
+                im = np.load(fn)
+            else:  # read image
+                im = cv2.imread(f)  # BGR
+            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
+        else:
+            sample = self.torch_transforms(self.loader(f))
+        return sample, j
+
+
+def create_classification_dataloader(path,
+                                     imgsz=224,
+                                     batch_size=16,
+                                     augment=True,
+                                     cache=False,
+                                     rank=-1,
+                                     workers=8,
+                                     shuffle=True):
+    # Returns Dataloader object to be used with YOLOv5 Classifier
+    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+        dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
+    batch_size = min(batch_size, len(dataset))
+    nd = torch.cuda.device_count()
+    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    generator = torch.Generator()
+    generator.manual_seed(0)
+    return InfiniteDataLoader(dataset,
+                              batch_size=batch_size,
+                              shuffle=shuffle and sampler is None,
+                              num_workers=nw,
+                              sampler=sampler,
+                              pin_memory=True,
+                              worker_init_fn=seed_worker,
+                              generator=generator)  # or DataLoader(persistent_workers=True)
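create_classification_dataloader() expects an ImageFolder-style layout (one subdirectory per class). A hypothetical call, with the dataset path as an assumption:

```python
from utils.dataloaders import create_classification_dataloader

loader = create_classification_dataloader(path='../datasets/imagenette160/train',  # hypothetical path
                                          imgsz=224,
                                          batch_size=16,
                                          augment=True,
                                          workers=4)
im, labels = next(iter(loader))
print(im.shape, labels.shape)  # torch.Size([16, 3, 224, 224]) torch.Size([16])
```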
diff --git a/utils/general.py b/utils/general.py
index 2a3ce37cd853..1c525c45f649 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -217,7 +217,11 @@ def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
     if args is None:  # get args automatically
         args, _, _, frm = inspect.getargvalues(x)
         args = {k: v for k, v in frm.items() if k in args}
-    s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
+    try:
+        file = Path(file).resolve().relative_to(ROOT).with_suffix('')
+    except ValueError:
+        file = Path(file).stem
+    s = (f'{file}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
     LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))


@@ -345,7 +349,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
 @try_except
 def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
-    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
+    # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
     prefix = colorstr('red', 'bold', 'requirements:')
     check_python()  # check python version
     if isinstance(requirements, (str, Path)):  # requirements.txt file
@@ -549,6 +553,18 @@ def amp_allclose(model, im):
     return False


+def yaml_load(file='data.yaml'):
+    # Single-line safe yaml loading
+    with open(file, errors='ignore') as f:
+        return yaml.safe_load(f)
+
+
+def yaml_save(file='data.yaml', data={}):
+    # Single-line safe yaml saving
+    with open(file, 'w') as f:
+        yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
+
+
 def url2file(url):
     # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
     url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
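A quick round trip of the two new helpers; note that yaml_save() stringifies pathlib.Path values so the resulting file stays plain YAML:

```python
from pathlib import Path
from utils.general import yaml_load, yaml_save

yaml_save('opt.yaml', {'weights': Path('yolov5s.pt'), 'epochs': 10})
print(yaml_load('opt.yaml'))  # {'weights': 'yolov5s.pt', 'epochs': 10}
```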
diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py
index 0f3eceafd0db..8ec846f8cfac 100644
--- a/utils/loggers/__init__.py
+++ b/utils/loggers/__init__.py
@@ -5,6 +5,7 @@

 import os
 import warnings
+from pathlib import Path

 import pkg_resources as pkg
 import torch
@@ -76,7 +77,7 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None,
             self.logger.info(s)
         if not clearml:
             prefix = colorstr('ClearML: ')
-            s = f"{prefix}run 'pip install clearml' to automatically track, visualize and remotely train YOLOv5 🚀 runs in ClearML"
+            s = f"{prefix}run 'pip install clearml' to automatically track, visualize and remotely train YOLOv5 🚀 in ClearML"
             self.logger.info(s)

         # TensorBoard
@@ -121,11 +122,8 @@ def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
         # Callback runs on train batch end
         # ni: number integrated batches (since train start)
         if plots:
-            if ni == 0:
-                if self.tb and not self.opt.sync_bn:  # --sync known issue https://github.com/ultralytics/yolov5/issues/3754
-                    with warnings.catch_warnings():
-                        warnings.simplefilter('ignore')  # suppress jit trace warning
-                        self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+            if ni == 0 and not self.opt.sync_bn and self.tb:
+                log_tensorboard_graph(self.tb, model, imgsz=list(imgs.shape[2:4]))
             if ni < 3:
                 f = self.save_dir / f'train_batch{ni}.jpg'  # filename
                 plot_images(imgs, targets, paths, f)
@@ -233,3 +231,78 @@ def on_params_update(self, params):
         # params: A dict containing {param: value} pairs
         if self.wandb:
             self.wandb.wandb_run.config.update(params, allow_val_change=True)
+
+
+class GenericLogger:
+    """
+    YOLOv5 General purpose logger for non-task specific logging
+    Usage: from utils.loggers import GenericLogger; logger = GenericLogger(...)
+    Arguments
+        opt:             Run arguments
+        console_logger:  Console logger
+        include:         loggers to include
+    """
+
+    def __init__(self, opt, console_logger, include=('tb', 'wandb')):
+        # init default loggers
+        self.save_dir = opt.save_dir
+        self.include = include
+        self.console_logger = console_logger
+        if 'tb' in self.include:
+            prefix = colorstr('TensorBoard: ')
+            self.console_logger.info(
+                f"{prefix}Start with 'tensorboard --logdir {self.save_dir.parent}', view at http://localhost:6006/")
+            self.tb = SummaryWriter(str(self.save_dir))
+
+        if wandb and 'wandb' in self.include:
+            self.wandb = wandb.init(project="YOLOv5-Classifier" if opt.project == "runs/train" else opt.project,
+                                    name=None if opt.name == "exp" else opt.name,
+                                    config=opt)
+        else:
+            self.wandb = None
+
+    def log_metrics(self, metrics_dict, epoch):
+        # Log metrics dictionary to all loggers
+        if self.tb:
+            for k, v in metrics_dict.items():
+                self.tb.add_scalar(k, v, epoch)
+
+        if self.wandb:
+            self.wandb.log(metrics_dict, step=epoch)
+
+    def log_images(self, files, name='Images', epoch=0):
+        # Log images to all loggers
+        files = [Path(f) for f in (files if isinstance(files, (tuple, list)) else [files])]  # to Path
+        files = [f for f in files if f.exists()]  # filter by exists
+
+        if self.tb:
+            for f in files:
+                self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
+
+        if self.wandb:
+            self.wandb.log({name: [wandb.Image(str(f), caption=f.name) for f in files]}, step=epoch)
+
+    def log_graph(self, model, imgsz=(640, 640)):
+        # Log model graph to all loggers
+        if self.tb:
+            log_tensorboard_graph(self.tb, model, imgsz)
+
+    def log_model(self, model_path, epoch=0, metadata={}):
+        # Log model to all loggers
+        if self.wandb:
+            art = wandb.Artifact(name=f"run_{wandb.run.id}_model", type="model", metadata=metadata)
+            art.add_file(str(model_path))
+            wandb.log_artifact(art)
+
+
+def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
+    # Log model graph to TensorBoard
+    try:
+        p = next(model.parameters())  # for device, type
+        imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz  # expand
+        im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p)  # input image
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')  # suppress jit trace warning
+            tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
+    except Exception:
+        print('WARNING: TensorBoard graph visualization failure')
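GenericLogger only reads a few attributes from opt (save_dir, project, name), so a bare Namespace is enough to drive it. A minimal sketch that logs to TensorBoard only, skipping wandb:

```python
from argparse import Namespace
from pathlib import Path
from utils.general import LOGGER
from utils.loggers import GenericLogger

opt = Namespace(save_dir=Path('runs/exp'), project='runs/train', name='exp')  # hypothetical run args
logger = GenericLogger(opt, console_logger=LOGGER, include=('tb',))           # 'wandb' omitted
logger.log_metrics({'train/loss': 0.42, 'metrics/accuracy_top1': 0.91}, epoch=0)
```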
diff --git a/utils/plots.py b/utils/plots.py
index d050f5d36aba..7417308c4d82 100644
--- a/utils/plots.py
+++ b/utils/plots.py
@@ -388,6 +388,35 @@ def plot_labels(labels, names=(), save_dir=Path('')):
     plt.close()


+def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
+    # Show classification image grid with labels (optional) and predictions (optional)
+    from utils.augmentations import denormalize
+
+    names = names or [f'class{i}' for i in range(1000)]
+    blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
+                         dim=0)  # select batch index 0, block by channels
+    n = min(len(blocks), nmax)  # number of plots
+    m = min(8, round(n ** 0.5))  # 8 x 8 default
+    fig, ax = plt.subplots(math.ceil(n / m), m)  # 8 rows x n/8 cols
+    ax = ax.ravel() if m > 1 else [ax]
+    # plt.subplots_adjust(wspace=0.05, hspace=0.05)
+    for i in range(n):
+        ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
+        ax[i].axis('off')
+        if labels is not None:
+            s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
+            ax[i].set_title(s, fontsize=8, verticalalignment='top')
+    plt.savefig(f, dpi=300, bbox_inches='tight')
+    plt.close()
+    if verbose:
+        LOGGER.info(f"Saving {f}")
+        if labels is not None:
+            LOGGER.info('True:     ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
+        if pred is not None:
+            LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
+    return f
+
+
 def plot_evolve(evolve_csv='path/to/evolve.csv'):  # from utils.plots import *; plot_evolve()
     # Plot evolve.csv hyp evolution results
     evolve_csv = Path(evolve_csv)
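imshow_cls() expects a batch that was normalized with the ImageNet stats (it calls denormalize() before plotting). A hypothetical smoke test on random data:

```python
import torch
from utils.plots import imshow_cls

im = torch.randn(8, 3, 224, 224)  # stand-in for a normalized image batch
f = imshow_cls(im, labels=list(range(8)), names=[f'class{i}' for i in range(10)], f='grid.jpg')
print(f)  # grid.jpg
```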
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 1ceb0aa346e9..1cdbe20f8670 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -42,6 +42,16 @@ def decorate(fn):
     return decorate


+def smartCrossEntropyLoss(label_smoothing=0.0):
+    # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
+    if check_version(torch.__version__, '1.10.0'):
+        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)  # loss function
+    else:
+        if label_smoothing > 0:
+            LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
+        return nn.CrossEntropyLoss()  # loss function
+
+
 def smart_DDP(model):
     # Model DDP creation with checks
     assert not check_version(torch.__version__, '1.12.0', pinned=True), \
@@ -53,6 +63,28 @@ def smart_DDP(model):
         return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)


+def reshape_classifier_output(model, n=1000):
+    # Update a TorchVision classification model to class count 'n' if required
+    from models.common import Classify
+    name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module
+    if isinstance(m, Classify):  # YOLOv5 Classify() head
+        if m.linear.out_features != n:
+            m.linear = nn.Linear(m.linear.in_features, n)
+    elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
+        if m.out_features != n:
+            setattr(model, name, nn.Linear(m.in_features, n))
+    elif isinstance(m, nn.Sequential):
+        types = [type(x) for x in m]
+        if nn.Linear in types:
+            i = types.index(nn.Linear)  # nn.Linear index
+            if m[i].out_features != n:
+                m[i] = nn.Linear(m[i].in_features, n)
+        elif nn.Conv2d in types:
+            i = types.index(nn.Conv2d)  # nn.Conv2d index
+            if m[i].out_channels != n:
+                m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias)
+
+
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     # Decorator to make all processes in distributed training wait for each local_master to do something
@@ -117,14 +149,13 @@ def time_sync():


 def profile(input, ops, n=10, device=None):
-    # YOLOv5 speed/memory/FLOPs profiler
-    #
-    # Usage:
-    #     input = torch.randn(16, 3, 640, 640)
-    #     m1 = lambda x: x * torch.sigmoid(x)
-    #     m2 = nn.SiLU()
-    #     profile(input, [m1, m2], n=100)  # profile over 100 iterations
-
+    """ YOLOv5 speed/memory/FLOPs profiler
+    Usage:
+        input = torch.randn(16, 3, 640, 640)
+        m1 = lambda x: x * torch.sigmoid(x)
+        m2 = nn.SiLU()
+        profile(input, [m1, m2], n=100)  # profile over 100 iterations
+    """
     results = []
     if not isinstance(device, torch.device):
         device = select_device(device)
@@ -313,6 +344,18 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
     return optimizer


+def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
+    # YOLOv5 torch.hub.load() wrapper with smart error/issue handling
+    if check_version(torch.__version__, '1.9.1'):
+        kwargs['skip_validation'] = True  # validation causes GitHub API rate limit errors
+    if check_version(torch.__version__, '1.12.0'):
+        kwargs['trust_repo'] = True  # argument required starting in torch 1.12
+    try:
+        return torch.hub.load(repo, model, **kwargs)
+    except Exception:
+        return torch.hub.load(repo, model, force_reload=True, **kwargs)
+
+
 def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
     # Resume training from a partially trained checkpoint
     best_fitness = 0.0
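reshape_classifier_output() dispatches on the type of the model's last child module; for a torchvision ResNet that is an nn.Linear, so the helper swaps model.fc in place. A minimal sketch, assuming a repo checkout so utils.torch_utils imports:

```python
import torchvision
from utils.torch_utils import reshape_classifier_output

model = torchvision.models.resnet18()   # fc: Linear(512, 1000)
reshape_classifier_output(model, n=10)  # fc becomes Linear(512, 10)
print(model.fc)
```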