Skip to content

Commit

Permalink
YOLOv5 + Albumentations integration (ultralytics#3882)
Browse files Browse the repository at this point in the history
* Albumentations integration

* ToGray p=0.01

* print confirmation

* create instance in dataloader init method

* improved version handling

* transform not defined fix

* assert string update

* create check_version()

* add spaces

* update class comment
  • Loading branch information
glenn-jocher committed Jul 5, 2021
1 parent b160cfe commit d39480f
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 27 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ pandas
# extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
# pycocotools>=2.0 # COCO mAP
# albumentations>=1.0.0
thop # FLOPs computation
30 changes: 29 additions & 1 deletion utils/augmentations.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,43 @@
# YOLOv5 image augmentation functions

import logging
import random

import cv2
import math
import numpy as np

from utils.general import segment2box, resample_segments
from utils.general import colorstr, segment2box, resample_segments, check_version
from utils.metrics import bbox_ioa


class Albumentations:
# YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self):
self.transform = None
try:
import albumentations as A
check_version(A.__version__, '1.0.0') # version requirement

self.transform = A.Compose([
A.Blur(p=0.1),
A.MedianBlur(p=0.1),
A.ToGray(p=0.01)],
bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms))
except ImportError: # package not installed, skip
pass
except Exception as e:
logging.info(colorstr('albumentations: ') + f'{e}')

def __call__(self, im, labels, p=1.0):
if self.transform and random.random() < p:
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
return im, labels


def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
# HSV color-space augmentation
if hgain or sgain or vgain:
Expand Down
40 changes: 21 additions & 19 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.augmentations import augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
xyn2xy, segments2boxes, clean_str
from utils.torch_utils import torch_distributed_zero_first
Expand Down Expand Up @@ -372,6 +372,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
self.mosaic_border = [-img_size // 2, -img_size // 2]
self.stride = stride
self.path = path
self.albumentations = Albumentations() if augment else None

try:
f = [] # image files
Expand Down Expand Up @@ -539,42 +540,43 @@ def __getitem__(self, index):
if labels.size: # normalized xywh to pixel xyxy format
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

if self.augment:
# Augment imagespace
if not mosaic:
if self.augment:
img, labels = random_perspective(img, labels,
degrees=hyp['degrees'],
translate=hyp['translate'],
scale=hyp['scale'],
shear=hyp['shear'],
perspective=hyp['perspective'])

# Augment colorspace
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

# Apply cutouts
# if random.random() < 0.9:
# labels = cutout(img, labels)

nL = len(labels) # number of labels
if nL:
nl = len(labels) # number of labels
if nl:
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized

if self.augment:
# flip up-down
# Albumentations
img, labels = self.albumentations(img, labels)

# HSV color-space
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

# Flip up-down
if random.random() < hyp['flipud']:
img = np.flipud(img)
if nL:
if nl:
labels[:, 2] = 1 - labels[:, 2]

# flip left-right
# Flip left-right
if random.random() < hyp['fliplr']:
img = np.fliplr(img)
if nL:
if nl:
labels[:, 1] = 1 - labels[:, 1]

labels_out = torch.zeros((nL, 6))
if nL:
# Cutouts
# if random.random() < 0.9:
# labels = cutout(img, labels)

labels_out = torch.zeros((nl, 6))
if nl:
labels_out[:, 1:] = torch.from_numpy(labels)

# Convert
Expand Down
17 changes: 10 additions & 7 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import glob
import logging
import math
import os
import platform
import random
Expand All @@ -17,6 +16,7 @@
from subprocess import check_output

import cv2
import math
import numpy as np
import pandas as pd
import pkg_resources as pkg
Expand Down Expand Up @@ -136,13 +136,16 @@ def check_git_status(err_msg=', for updates see https://github.com/ultralytics/y
print(f'{e}{err_msg}')


def check_python(minimum='3.6.2', required=True):
def check_python(minimum='3.6.2'):
# Check current python version vs. required python version
current = platform.python_version()
result = pkg.parse_version(current) >= pkg.parse_version(minimum)
if required:
assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
return result
check_version(platform.python_version(), minimum, name='Python ')


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
# Check version vs. required version
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum)
assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'


def check_requirements(requirements='requirements.txt', exclude=()):
Expand Down

0 comments on commit d39480f

Please sign in to comment.