add mosaic and warmup to hyperparameters (#931)
glenn-jocher committed Sep 13, 2020
1 parent 806e75f commit f1c63e2
Showing 4 changed files with 22 additions and 9 deletions.
4 changes: 4 additions & 0 deletions data/hyp.finetune.yaml
@@ -12,6 +12,9 @@ lr0: 0.0032
lrf: 0.12
momentum: 0.843
weight_decay: 0.00036
+warmup_epochs: 2.0
+warmup_momentum: 0.5
+warmup_bias_lr: 0.05
giou: 0.0296
cls: 0.243
cls_pw: 0.631
@@ -31,4 +34,5 @@ shear: 0.602
perspective: 0.0
flipud: 0.00856
fliplr: 0.5
+mosaic: 1.0
mixup: 0.243
4 changes: 4 additions & 0 deletions data/hyp.scratch.yaml
@@ -7,6 +7,9 @@ lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.2 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
+warmup_epochs: 3.0 # warmup epochs (fractions ok)
+warmup_momentum: 0.8 # warmup initial momentum
+warmup_bias_lr: 0.1 # warmup initial bias lr
giou: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
@@ -26,4 +29,5 @@ shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
+mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
18 changes: 11 additions & 7 deletions train.py
@@ -202,7 +202,7 @@ def train(hyp, opt, device, tb_writer=None):

# Start training
t0 = time.time()
-nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
+nw = max(round(hyp['warmup_epochs'] * nb), 1e3) # number of warmup iterations, max(warmup_epochs, 1k iterations)
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
@@ -250,9 +250,9 @@ def train(hyp, opt, device, tb_writer=None):
accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
for j, x in enumerate(optimizer.param_groups):
# bias lr falls from warmup_bias_lr to lr0, all other lrs rise from 0.0 to lr0
-x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
+x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
if 'momentum' in x:
-x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
+x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

# Multi-scale
if opt.multi_scale:
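
The two hunks above replace the hard-coded warmup (3 epochs, bias lr starting at 0.1, momentum starting at 0.9) with the new warmup_epochs, warmup_bias_lr and warmup_momentum hyperparameters. Below is a minimal standalone sketch of the resulting schedule, assuming the hyp.scratch.yaml defaults and a made-up nb of 500 batches per epoch; the real loop interpolates toward x['initial_lr'] * lf(epoch) per parameter group, simplified here to lr0.

    # Sketch only: reproduces the warmup interpolation, not the training loop itself.
    import numpy as np

    hyp = {'lr0': 0.01, 'momentum': 0.937,
           'warmup_epochs': 3.0, 'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}
    nb = 500                                         # assumed batches per epoch
    nw = max(round(hyp['warmup_epochs'] * nb), 1e3)  # warmup iterations, at least 1k
    xi = [0, nw]                                     # interpolation endpoints

    for ni in (0, nw // 2, nw):                      # integrated batch number
        bias_lr = np.interp(ni, xi, [hyp['warmup_bias_lr'], hyp['lr0']])   # falls to lr0
        other_lr = np.interp(ni, xi, [0.0, hyp['lr0']])                    # rises to lr0
        momentum = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
        print(f'ni={ni:6.0f}  bias_lr={bias_lr:.4f}  other_lr={other_lr:.4f}  momentum={momentum:.3f}')
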
@@ -460,16 +460,19 @@ def train(hyp, opt, device, tb_writer=None):
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
-'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
+'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
+'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
+'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
+'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
'giou': (1, 0.02, 0.2), # GIoU loss gain
'cls': (1, 0.2, 4.0), # cls loss gain
'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (0, 0.1, 0.7), # IoU training threshold
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
-'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
+'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
@@ -481,6 +484,7 @@
'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
+'mosaic': (1, 0.0, 1.0), # image mosaic (probability)
'mixup': (1, 0.0, 1.0)} # image mixup (probability)

assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
@@ -490,7 +494,7 @@
if opt.bucket:
os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists

-for _ in range(1): # generations to evolve
+for _ in range(300): # generations to evolve
if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
# Select parent(s)
parent = 'single' # parent selection method: 'single' or 'weighted'
@@ -505,7 +509,7 @@
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination

# Mutate
-mp, s = 0.9, 0.2 # mutation probability, sigma
+mp, s = 0.8, 0.2 # mutation probability, sigma
npr = np.random
npr.seed(int(time.time()))
g = np.array([x[0] for x in meta.values()]) # gains 0-1
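
For context, the meta table above assigns each hyperparameter a mutation gain and a lower/upper limit, and the hunks just shown raise the generation count to 300 and lower the mutation probability to 0.8. The following is an illustrative sketch of gain-scaled, clipped mutation in that style, using a reduced meta/hyp subset; it is not a verbatim copy of train.py's evolve loop.

    # Illustration only: gain-scaled, clipped mutation of a few hyperparameters.
    import numpy as np

    meta = {'lr0': (1, 1e-5, 1e-1),        # (mutation gain, lower limit, upper limit)
            'momentum': (0.3, 0.6, 0.98),
            'mosaic': (1, 0.0, 1.0)}
    hyp = {'lr0': 0.01, 'momentum': 0.937, 'mosaic': 1.0}   # parent values (assumed)

    mp, s = 0.8, 0.2                                  # mutation probability, sigma
    npr = np.random
    g = np.array([m[0] for m in meta.values()])       # per-hyperparameter gains
    ng = len(meta)

    v = np.ones(ng)
    while all(v == 1):                                # resample until something mutates
        v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)

    for (k, (_, low, high)), vi in zip(meta.items(), v):
        hyp[k] = float(np.clip(hyp[k] * vi, low, high))  # mutate and clamp to limits
    print(hyp)
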
5 changes: 3 additions & 2 deletions utils/datasets.py
@@ -516,7 +516,8 @@ def __getitem__(self, index):
index = self.indices[index]

hyp = self.hyp
-if self.mosaic:
+mosaic = self.mosaic and random.random() < hyp['mosaic']
+if mosaic:
# Load mosaic
img, labels = load_mosaic(self, index)
shapes = None
@@ -550,7 +551,7 @@ def __getitem__(self, index):

if self.augment:
# Augment imagespace
-if not self.mosaic:
+if not mosaic:
img, labels = random_perspective(img, labels,
degrees=hyp['degrees'],
translate=hyp['translate'],
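
The datasets.py change turns mosaic from an on/off dataset flag into a per-image probability drawn against hyp['mosaic']. Below is a small sketch of that gating pattern; load_mosaic_example and load_single_example are hypothetical stand-ins, not the dataset's real loaders, and only the gating logic mirrors the diff.

    # Sketch of probability-gated mosaic selection.
    import random

    def load_mosaic_example(index):
        return f'mosaic sample {index}'       # stand-in for the 4-image mosaic path

    def load_single_example(index):
        return f'single sample {index}'       # stand-in for the plain single-image path

    def get_item(index, hyp, dataset_uses_mosaic=True):
        mosaic = dataset_uses_mosaic and random.random() < hyp['mosaic']
        return load_mosaic_example(index) if mosaic else load_single_example(index)

    print(get_item(0, {'mosaic': 1.0}))   # always the mosaic path
    print(get_item(0, {'mosaic': 0.0}))   # never the mosaic path
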
