From 4f39476d8913de7cf8d7fdbf71945b62332e9501 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 30 Jun 2021 19:45:02 +0200 Subject: [PATCH 1/5] Copy-paste augmentation initial commit --- utils/datasets.py | 42 ++++++++++++++++++++++++------------------ utils/metrics.py | 26 +++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index 5baf9c5b1906..bc8020dd174d 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -25,6 +25,7 @@ from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \ xyn2xy, segment2box, segments2boxes, resample_segments, clean_str +from utils.metrics import bbox_ioa from utils.torch_utils import torch_distributed_zero_first # Parameters @@ -683,6 +684,7 @@ def load_mosaic(self, index): # img4, labels4 = replicate(img4, labels4) # replicate # Augment + img4, labels4, segments4 = copy_paste(img4, labels4, segments4) img4, labels4 = random_perspective(img4, labels4, segments4, degrees=self.hyp['degrees'], translate=self.hyp['translate'], @@ -907,6 +909,28 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s return img, targets +def copy_paste(img, labels, segments, fraction=0.5): + # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels in labels in xyxy + h, w, c = img.shape # height, width, channels + im_new = np.zeros(img.shape, np.uint8) + n = len(segments) + labels = labels.tolist() + for j in random.sample(range(n), k=round(fraction * n)): + l, s = labels[j], segments[j] + labels.append([l[0], w - l[3], l[2], w - l[1], l[4]]) + segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) + cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) + + result = cv2.bitwise_and(src1=img, src2=im_new) + result = cv2.flip(result, 1) # flip left-right + + i = result > 0 + # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch + img[i] = result[i] + # cv2.imwrite('debug.jpg', img) # debug + return img, np.array(labels), segments + + def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio w1, h1 = box1[2] - box1[0], box1[3] - box1[1] @@ -919,24 +943,6 @@ def cutout(image, labels): # Applies image cutout augmentation https://arxiv.org/abs/1708.04552 h, w = image.shape[:2] - def bbox_ioa(box1, box2): - # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. 
boxes are x1y1x2y2 - box2 = box2.transpose() - - # Get the coordinates of bounding boxes - b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] - b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] - - # Intersection area - inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \ - (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0) - - # box2 area - box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16 - - # Intersection over box2 area - return inter_area / box2_area - # create random masks scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction for s in scales: diff --git a/utils/metrics.py b/utils/metrics.py index 4f001c046285..c94c4a76a964 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -1,9 +1,9 @@ # Model validation metrics -import math import warnings from pathlib import Path +import math import matplotlib.pyplot as plt import numpy as np import torch @@ -253,6 +253,30 @@ def box_area(box): return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) +def bbox_ioa(box1, box2, eps=1E-7): + """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2 + box1: np.array of shape(4) + box2: np.array of shape(nx4) + returns: np.array of shape(n) + """ + + box2 = box2.transpose() + + # Get the coordinates of bounding boxes + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + + # Intersection area + inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \ + (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0) + + # box2 area + box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps + + # Intersection over box2 area + return inter_area / box2_area + + def wh_iou(wh1, wh2): # Returns the nxm IoU matrix. 
wh1 is nx2, wh2 is mx2 wh1 = wh1[:, None] # [N,1,2] From 87e24489bdb629becf3f850eb45db4d011652998 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 30 Jun 2021 20:11:38 +0200 Subject: [PATCH 2/5] if any segments --- utils/datasets.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index bc8020dd174d..be743dbe5fe5 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -911,23 +911,25 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s def copy_paste(img, labels, segments, fraction=0.5): # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels in labels in xyxy - h, w, c = img.shape # height, width, channels - im_new = np.zeros(img.shape, np.uint8) - n = len(segments) - labels = labels.tolist() - for j in random.sample(range(n), k=round(fraction * n)): - l, s = labels[j], segments[j] - labels.append([l[0], w - l[3], l[2], w - l[1], l[4]]) - segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) - cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) - - result = cv2.bitwise_and(src1=img, src2=im_new) - result = cv2.flip(result, 1) # flip left-right - - i = result > 0 - # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch - img[i] = result[i] - # cv2.imwrite('debug.jpg', img) # debug + if any(segments): + n = len(segments) + h, w, c = img.shape # height, width, channels + im_new = np.zeros(img.shape, np.uint8) + labels = labels.tolist() + for j in random.sample(range(n), k=round(fraction * n)): + l, s = labels[j], segments[j] + labels.append([l[0], w - l[3], l[2], w - l[1], l[4]]) + segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) + cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) + + result = cv2.bitwise_and(src1=img, src2=im_new) + result = cv2.flip(result, 1) # flip left-right + + i = result > 0 + # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch + img[i] = result[i] + # cv2.imwrite('debug.jpg', img) # debug + return img, np.array(labels), segments From 5370d8035e0e06aa28d2133d9d43cac76d3e52d8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 1 Jul 2021 00:15:01 +0200 Subject: [PATCH 3/5] Add obscuration rejection --- utils/datasets.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index be743dbe5fe5..293b875ab855 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -909,28 +909,28 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s return img, targets -def copy_paste(img, labels, segments, fraction=0.5): +def copy_paste(img, labels, segments, probability=0.5): # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels in labels in xyxy - if any(segments): - n = len(segments) + n = len(segments) + if n: h, w, c = img.shape # height, width, channels im_new = np.zeros(img.shape, np.uint8) - labels = labels.tolist() - for j in random.sample(range(n), k=round(fraction * n)): + for j in random.sample(range(n), k=round(probability * n)): l, s = labels[j], segments[j] - labels.append([l[0], w - l[3], l[2], w - l[1], l[4]]) - segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) - cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) + box = w - l[3], l[2], w - l[1], l[4] + ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area + if (ioa < 
0.30).all(): # allow 30% obscuration of existing labels + labels = np.concatenate((labels, [[l[0], *box]]), 0) + segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) + cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) result = cv2.bitwise_and(src1=img, src2=im_new) - result = cv2.flip(result, 1) # flip left-right - + result = cv2.flip(result, 1) # augment segments (flip left-right) i = result > 0 # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch - img[i] = result[i] - # cv2.imwrite('debug.jpg', img) # debug + img[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug - return img, np.array(labels), segments + return img, labels, segments def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) From efeea7a28869d47769840ca1e53e30d3574abd62 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 1 Jul 2021 00:20:43 +0200 Subject: [PATCH 4/5] Add copy_paste hyperparameter --- data/hyps/hyp.finetune.yaml | 1 + data/hyps/hyp.finetune_objects365.yaml | 1 + data/hyps/hyp.scratch-p6.yaml | 1 + data/hyps/hyp.scratch.yaml | 1 + train.py | 3 ++- utils/datasets.py | 6 +++--- 6 files changed, 9 insertions(+), 4 deletions(-) diff --git a/data/hyps/hyp.finetune.yaml b/data/hyps/hyp.finetune.yaml index a77597741356..237cd5bc19a1 100644 --- a/data/hyps/hyp.finetune.yaml +++ b/data/hyps/hyp.finetune.yaml @@ -36,3 +36,4 @@ flipud: 0.00856 fliplr: 0.5 mosaic: 1.0 mixup: 0.243 +copy_paste: 0.0 diff --git a/data/hyps/hyp.finetune_objects365.yaml b/data/hyps/hyp.finetune_objects365.yaml index 2b104ef2d9bf..435fa7a45119 100644 --- a/data/hyps/hyp.finetune_objects365.yaml +++ b/data/hyps/hyp.finetune_objects365.yaml @@ -26,3 +26,4 @@ flipud: 0.0 fliplr: 0.5 mosaic: 1.0 mixup: 0.0 +copy_paste: 0.0 diff --git a/data/hyps/hyp.scratch-p6.yaml b/data/hyps/hyp.scratch-p6.yaml index faf565423968..fc1d8ebe0876 100644 --- a/data/hyps/hyp.scratch-p6.yaml +++ b/data/hyps/hyp.scratch-p6.yaml @@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability) fliplr: 0.5 # image flip left-right (probability) mosaic: 1.0 # image mosaic (probability) mixup: 0.0 # image mixup (probability) +copy_paste: 0.0 # segment copy-paste (probability) diff --git a/data/hyps/hyp.scratch.yaml b/data/hyps/hyp.scratch.yaml index 44f26b6658ae..b2cf2e32c638 100644 --- a/data/hyps/hyp.scratch.yaml +++ b/data/hyps/hyp.scratch.yaml @@ -31,3 +31,4 @@ flipud: 0.0 # image flip up-down (probability) fliplr: 0.5 # image flip left-right (probability) mosaic: 1.0 # image mosaic (probability) mixup: 0.0 # image mixup (probability) +copy_paste: 0.0 # segment copy-paste (probability) diff --git a/train.py b/train.py index 257be065f641..8c68c5b76d28 100644 --- a/train.py +++ b/train.py @@ -591,7 +591,8 @@ def main(opt): 'flipud': (1, 0.0, 1.0), # image flip up-down (probability) 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability) 'mosaic': (1, 0.0, 1.0), # image mixup (probability) - 'mixup': (1, 0.0, 1.0)} # image mixup (probability) + 'mixup': (1, 0.0, 1.0), # image mixup (probability) + 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability) with open(opt.hyp) as f: hyp = yaml.safe_load(f) # load hyps dict diff --git a/utils/datasets.py b/utils/datasets.py index 293b875ab855..a521aa65732f 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -684,7 +684,7 @@ def load_mosaic(self, index): # img4, labels4 = replicate(img4, labels4) # replicate # Augment - img4, labels4, segments4 = copy_paste(img4, labels4, segments4) + img4, labels4, 
segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste']) img4, labels4 = random_perspective(img4, labels4, segments4, degrees=self.hyp['degrees'], translate=self.hyp['translate'], @@ -912,7 +912,7 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s def copy_paste(img, labels, segments, probability=0.5): # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels in labels in xyxy n = len(segments) - if n: + if probability and n: h, w, c = img.shape # height, width, channels im_new = np.zeros(img.shape, np.uint8) for j in random.sample(range(n), k=round(probability * n)): @@ -926,7 +926,7 @@ def copy_paste(img, labels, segments, probability=0.5): result = cv2.bitwise_and(src1=img, src2=im_new) result = cv2.flip(result, 1) # augment segments (flip left-right) - i = result > 0 + i = result > 0 # pixels to replace # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch img[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug From 3248ab1eb34cfa8e7d3e751406a70f5dc02e5974 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 1 Jul 2021 00:22:34 +0200 Subject: [PATCH 5/5] Update comments --- train.py | 2 +- utils/datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 8c68c5b76d28..386f3d90dd73 100644 --- a/train.py +++ b/train.py @@ -6,7 +6,6 @@ import argparse import logging -import math import os import random import sys @@ -16,6 +15,7 @@ from pathlib import Path from threading import Thread +import math import numpy as np import torch.distributed as dist import torch.nn as nn diff --git a/utils/datasets.py b/utils/datasets.py index a521aa65732f..55f046cd56db 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -910,7 +910,7 @@ def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, s def copy_paste(img, labels, segments, probability=0.5): - # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels in labels in xyxy + # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) n = len(segments) if probability and n: h, w, c = img.shape # height, width, channels
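
For quick reference outside the repository, the following sketch consolidates the end state of this series: bbox_ioa() as it lands in utils/metrics.py and copy_paste() as it lands in utils/datasets.py (commented-out debug lines omitted). The __main__ demo is illustrative only; the grey canvas, single triangular segment and probability=1.0 are assumptions for the example, not part of the patches. In the repository the probability comes from the new copy_paste hyperparameter (0.0 in the shipped hyp files), wired through load_mosaic() as copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste']).

import random

import cv2
import numpy as np


def bbox_ioa(box1, box2, eps=1e-7):
    # Intersection over box2 area. box1 is shape (4,), box2 is shape (n, 4), both x1y1x2y2.
    box2 = box2.transpose()
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                 (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
    return inter_area / ((b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps)  # intersection over box2 area


def copy_paste(img, labels, segments, probability=0.5):
    # Copy-Paste augmentation (https://arxiv.org/abs/2012.07177); labels are nx5 np.array(cls, xyxy),
    # segments is a list of (k, 2) polygon arrays in pixel coordinates.
    n = len(segments)
    if probability and n:
        h, w, c = img.shape  # height, width, channels
        im_new = np.zeros(img.shape, np.uint8)  # mask of segments selected for pasting
        for j in random.sample(range(n), k=round(probability * n)):
            l, s = labels[j], segments[j]
            box = w - l[3], l[2], w - l[1], l[4]  # box of the left-right mirrored instance
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over existing-label area
            if (ioa < 0.30).all():  # reject pastes that would obscure an existing label by 30% or more
                labels = np.concatenate((labels, [[l[0], *box]]), 0)
                segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
                cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)

        result = cv2.bitwise_and(src1=img, src2=im_new)  # keep only the selected segment pixels
        result = cv2.flip(result, 1)  # mirror them left-right
        i = result > 0  # pixels to replace
        img[i] = result[i]

    return img, labels, segments


if __name__ == '__main__':
    # Illustrative demo: one labelled triangle on a grey canvas, pasted with probability 1.0.
    img = np.full((640, 640, 3), 114, np.uint8)
    seg = np.array([[50., 50.], [250., 50.], [150., 250.]])  # one polygon, (k, 2) xy
    cv2.drawContours(img, [seg.astype(np.int32)], -1, (60, 160, 220), cv2.FILLED)
    labels = np.array([[0, 50., 50., 250., 250.]])  # (cls, x1, y1, x2, y2)
    img, labels, segments = copy_paste(img, labels, [seg], probability=1.0)
    print(labels)  # original label plus the mirrored copy when the obscuration check passes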