From f1c63e2784dffc9c2ccacb9c5a2e753bacf011ab Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 13 Sep 2020 14:03:54 -0700
Subject: [PATCH] add mosaic and warmup to hyperparameters (#931)

---
 data/hyp.finetune.yaml |  4 ++++
 data/hyp.scratch.yaml  |  4 ++++
 train.py               | 18 +++++++++++-------
 utils/datasets.py      |  5 +++--
 4 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/data/hyp.finetune.yaml b/data/hyp.finetune.yaml
index 74b55bd9c1e1..fe9cd55019f7 100644
--- a/data/hyp.finetune.yaml
+++ b/data/hyp.finetune.yaml
@@ -12,6 +12,9 @@ lr0: 0.0032
 lrf: 0.12
 momentum: 0.843
 weight_decay: 0.00036
+warmup_epochs: 2.0
+warmup_momentum: 0.5
+warmup_bias_lr: 0.05
 giou: 0.0296
 cls: 0.243
 cls_pw: 0.631
@@ -31,4 +34,5 @@ shear: 0.602
 perspective: 0.0
 flipud: 0.00856
 fliplr: 0.5
+mosaic: 1.0
 mixup: 0.243
diff --git a/data/hyp.scratch.yaml b/data/hyp.scratch.yaml
index 3beaf09f5017..9f53e86dd3ab 100644
--- a/data/hyp.scratch.yaml
+++ b/data/hyp.scratch.yaml
@@ -7,6 +7,9 @@ lr0: 0.01  # initial learning rate (SGD=1E-2, Adam=1E-3)
 lrf: 0.2  # final OneCycleLR learning rate (lr0 * lrf)
 momentum: 0.937  # SGD momentum/Adam beta1
 weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
 giou: 0.05  # box loss gain
 cls: 0.5  # cls loss gain
 cls_pw: 1.0  # cls BCELoss positive_weight
@@ -26,4 +29,5 @@ shear: 0.0  # image shear (+/- deg)
 perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
 flipud: 0.0  # image flip up-down (probability)
 fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
 mixup: 0.0  # image mixup (probability)
diff --git a/train.py b/train.py
index 4aae4cbf3060..79875377d73a 100644
--- a/train.py
+++ b/train.py
@@ -202,7 +202,7 @@ def train(hyp, opt, device, tb_writer=None):
 
     # Start training
     t0 = time.time()
-    nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
+    nw = max(round(hyp['warmup_epochs'] * nb), 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
@@ -250,9 +250,9 @@ def train(hyp, opt, device, tb_writer=None):
                 accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                    x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
+                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                     if 'momentum' in x:
-                        x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
+                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
 
             # Multi-scale
             if opt.multi_scale:
@@ -460,8 +460,11 @@ def train(hyp, opt, device, tb_writer=None):
         # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
         meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                 'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
-                'momentum': (0.1, 0.6, 0.98),  # SGD momentum/Adam beta1
+                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                 'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
+                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
+                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
+                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                 'giou': (1, 0.02, 0.2),  # GIoU loss gain
                 'cls': (1, 0.2, 4.0),  # cls loss gain
                 'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
@@ -469,7 +472,7 @@ def train(hyp, opt, device, tb_writer=None):
                 'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                 'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                 'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
-                'anchors': (1, 2.0, 10.0),  # anchors per output grid (0 to ignore)
+                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                 'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                 'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                 'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
@@ -481,6 +484,7 @@ def train(hyp, opt, device, tb_writer=None):
                 'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                 'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                 'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
+                'mosaic': (1, 0.0, 1.0),  # image mixup (probability)
                 'mixup': (1, 0.0, 1.0)}  # image mixup (probability)
 
         assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
@@ -490,7 +494,7 @@ def train(hyp, opt, device, tb_writer=None):
         if opt.bucket:
             os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists
 
-        for _ in range(1):  # generations to evolve
+        for _ in range(300):  # generations to evolve
             if os.path.exists('evolve.txt'):  # if evolve.txt exists: select best hyps and mutate
                 # Select parent(s)
                 parent = 'single'  # parent selection method: 'single' or 'weighted'
@@ -505,7 +509,7 @@ def train(hyp, opt, device, tb_writer=None):
                     x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination
 
                 # Mutate
-                mp, s = 0.9, 0.2  # mutation probability, sigma
+                mp, s = 0.8, 0.2  # mutation probability, sigma
                 npr = np.random
                 npr.seed(int(time.time()))
                 g = np.array([x[0] for x in meta.values()])  # gains 0-1
diff --git a/utils/datasets.py b/utils/datasets.py
index fd2050d7011c..4be6a81e32af 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -516,7 +516,8 @@ def __getitem__(self, index):
         index = self.indices[index]
 
         hyp = self.hyp
-        if self.mosaic:
+        mosaic = self.mosaic and random.random() < hyp['mosaic']
+        if mosaic:
             # Load mosaic
             img, labels = load_mosaic(self, index)
             shapes = None
@@ -550,7 +551,7 @@ def __getitem__(self, index):
 
         if self.augment:
             # Augment imagespace
-            if not self.mosaic:
+            if not mosaic:
                 img, labels = random_perspective(img, labels,
                                                  degrees=hyp['degrees'],
                                                  translate=hyp['translate'],
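
For readers skimming the patch, the following is an illustrative sketch (not part of the commit, and not the real train.py loop) of what the new warmup hyperparameters drive: during the first nw iterations the bias learning rate falls from warmup_bias_lr to the scheduled rate, all other learning rates ramp up from 0.0, and momentum ramps from warmup_momentum to its nominal value. The lr0 and nb values below are assumed stand-ins chosen to mirror hyp.scratch.yaml.

# Illustrative sketch only -- hyp values mirror hyp.scratch.yaml; lr0, nb and the
# per-epoch schedule are simplified stand-ins, not the actual train.py code.
import numpy as np

hyp = {'warmup_epochs': 3.0, 'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1, 'momentum': 0.937}
lr0 = 0.01  # assumed nominal learning rate (hyp.scratch.yaml lr0)
nb = 100    # assumed number of batches per epoch
nw = max(round(hyp['warmup_epochs'] * nb), 1e3)  # warmup iterations, at least 1k

def warmup_state(ni):
    """Return (bias_lr, other_lr, momentum) at integrated batch count ni."""
    xi = [0, nw]  # interpolation endpoints: start and end of warmup
    bias_lr = np.interp(ni, xi, [hyp['warmup_bias_lr'], lr0])   # bias lr falls warmup_bias_lr -> lr0
    other_lr = np.interp(ni, xi, [0.0, lr0])                    # other lrs rise 0.0 -> lr0
    momentum = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
    return bias_lr, other_lr, momentum

for ni in (0, 250, 500, 1000):
    print(ni, [round(v, 4) for v in warmup_state(ni)])

The utils/datasets.py change is simpler: mosaic loading is now gated per image by random.random() < hyp['mosaic'], so mosaic: 1.0 reproduces the previous always-on behaviour, while smaller values mix ordinary letterboxed images into training.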