From a93e484820b386a37b298edf6599bddf33c8456f Mon Sep 17 00:00:00 2001
From: yellowdolphin <42343818+yellowdolphin@users.noreply.github.com>
Date: Tue, 22 Jun 2021 03:10:00 +0200
Subject: [PATCH 1/6] gradient accumulation during warmup in train.py

Context: `accumulate` is the number of batches/gradients accumulated before the next optimizer.step() call. During warmup, it is ramped up from 1 to its final value nbs / batch_size. Although I have not seen this in other libraries, I like the idea: during warmup, gradients are large, so overly large steps are more of an issue than the gradient noise caused by small steps.

The bug: the condition that triggers the optimizer step is wrong:

> if ni % accumulate == 0:

This produces irregular step sizes whenever `accumulate` is not constant. It becomes relevant when batch_size is small and `accumulate` changes many times during warmup.

The linked demo illustrates the problem and also shows the proposed solution, using a ">=" condition instead:
https://colab.research.google.com/drive/1MA2z2eCXYB_BC5UZqgXueqL_y1Tz_XVq?usp=sharing

Further, I propose not restricting the number of warmup iterations to >= 1000. If the user changes hyp['warmup_epochs'], this causes unexpected behavior. It also makes evolution unstable if this parameter were to be optimized.
---
 train.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 9ac12b12aacf..55a16141d2b3 100644
--- a/train.py
+++ b/train.py
@@ -268,8 +268,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Start training
     t0 = time.time()
-    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
+    nw = round(hyp['warmup_epochs'] * nb)  # number of warmup iterations
     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
+    last_opt_step = -1
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
@@ -344,12 +345,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             scaler.scale(loss).backward()
 
             # Optimize
-            if ni % accumulate == 0:
+            if ni - last_opt_step >= accumulate:
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
                 if ema:
                     ema.update(model)
+                last_opt_step = ni
 
             # Print
             if RANK in [-1, 0]:
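[Editor's note: the sketch below is not part of the patch. It replays the two step conditions with made-up numbers (a 6-iteration warmup ramping `accumulate` from 1 to 6) so the irregular spacing described above becomes visible.]

```python
# Illustration only: compare the old modulo condition with the proposed ">=" condition
# while `accumulate` is ramped the same way train.py does it (np.interp + round).
import numpy as np

nw, final_acc = 6, 6                      # made-up warmup length and target accumulate
last_opt_step, mod_steps, ge_steps = -1, [], []

for ni in range(19):                      # ni = number of integrated batches
    accumulate = max(1, np.interp(ni, [0, nw], [1, final_acc]).round())
    if ni % accumulate == 0:              # old condition
        mod_steps.append(ni)
    if ni - last_opt_step >= accumulate:  # proposed condition
        ge_steps.append(ni)
        last_opt_step = ni

print(mod_steps)  # [0, 4, 5, 6, 12, 18]: three back-to-back steps, then a 6-iteration gap
print(ge_steps)   # [0, 4, 10, 16]: gaps grow smoothly with `accumulate` (4, 6, 6)
```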
From ce33f84d6756725e8fa067c43185b54ad0d7e79f Mon Sep 17 00:00:00 2001
From: greendolphin
Date: Wed, 23 Jun 2021 20:35:27 +0200
Subject: [PATCH 2/6] replace last_opt_step tracking by do_step(ni)

---
 train.py | 33 ++++++++++++++++++++++++++-------
 1 file changed, 26 insertions(+), 7 deletions(-)

diff --git a/train.py b/train.py
index 55a16141d2b3..65144dc90dc4 100644
--- a/train.py
+++ b/train.py
@@ -53,6 +53,28 @@
 WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 
+def get_step_condition(nw, accumulate):
+    nw, accumulate = int(nw), int(accumulate)
+
+    def acc(ni):
+        return int(1 + (accumulate - 1) / nw * ni)
+
+    warmup_steps, ni = [], nw - 1
+    while ni >= 0:
+        warmup_steps.append(ni)
+        ni -= acc(ni)
+    warmup_steps = set(warmup_steps)
+
+    def step_condition(ni):
+        "Return whether optimizer.step() is to be called in iteration `ni`"
+        nonlocal nw, accumulate, warmup_steps
+        i = ni - nw
+        if (i >= 0) and ((i + 1) % accumulate == 0): return True
+        return ni in warmup_steps
+
+    return step_condition
+
+
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
@@ -141,7 +163,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Optimizer
     nbs = 64  # nominal batch size
-    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
+    accumulate = max(int(round(nbs / batch_size)), 1)  # accumulate loss before optimizing
+    nw = round(hyp['warmup_epochs'] * nb)  # number of warmup iterations
+    do_step = get_step_condition(nw, accumulate)
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
 
@@ -268,9 +292,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Start training
     t0 = time.time()
-    nw = round(hyp['warmup_epochs'] * nb)  # number of warmup iterations
-    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
-    last_opt_step = -1
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
@@ -317,7 +338,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if ni <= nw:
                 xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
-                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                     x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
@@ -345,13 +365,12 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             scaler.scale(loss).backward()
 
             # Optimize
-            if ni - last_opt_step >= accumulate:
+            if do_step(ni):
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
                 if ema:
                     ema.update(model)
-                last_opt_step = ni
 
             # Print
             if RANK in [-1, 0]:

From b7c489f06221df0680f906114e45ba96ed71bbc7 Mon Sep 17 00:00:00 2001
From: greendolphin
Date: Wed, 23 Jun 2021 21:01:31 +0200
Subject: [PATCH 3/6] add docstrings

---
 train.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/train.py b/train.py
index 65144dc90dc4..a0eec21729ef 100644
--- a/train.py
+++ b/train.py
@@ -54,9 +54,11 @@
 
 
 def get_step_condition(nw, accumulate):
+    "Supports gradient accumulation and warmup for `nw` iterations"
     nw, accumulate = int(nw), int(accumulate)
 
     def acc(ni):
+        "Ramp up number of accumulated grads from 1 to `accumulate`"
        return int(1 + (accumulate - 1) / nw * ni)
 
     warmup_steps, ni = [], nw - 1

From dd5a57459779b7da57a858df763416fe18931854 Mon Sep 17 00:00:00 2001
From: greendolphin
Date: Thu, 24 Jun 2021 02:07:21 +0200
Subject: [PATCH 4/6] move down nw

---
 train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index a0eec21729ef..58bbc527928d 100644
--- a/train.py
+++ b/train.py
@@ -166,8 +166,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Optimizer
     nbs = 64  # nominal batch size
     accumulate = max(int(round(nbs / batch_size)), 1)  # accumulate loss before optimizing
-    nw = round(hyp['warmup_epochs'] * nb)  # number of warmup iterations
-    do_step = get_step_condition(nw, accumulate)
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
 
@@ -294,6 +292,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Start training
     t0 = time.time()
+    nw = round(hyp['warmup_epochs'] * nb)  # number of warmup iterations
+    do_step = get_step_condition(nw, accumulate)
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
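[Editor's note: not part of the patches. For reference, this is the step schedule the `get_step_condition` helper above produces for made-up values (nw=10 warmup iterations, accumulate=4); the definition from PATCH 2/6 is assumed to be in scope.]

```python
# Illustration only: step schedule from the helper added in PATCH 2/6,
# assuming get_step_condition() as defined above is in scope.
do_step = get_step_condition(nw=10, accumulate=4)

print([ni for ni in range(20) if do_step(ni)])
# [0, 1, 2, 4, 6, 9, 13, 17]: during warmup the gap between steps grows
# (1, 1, 2, 2, 3); afterwards a step lands every `accumulate` = 4 iterations.
```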
From 53aae7fd85093c92ed94e544cc7d6a6c89ee106d Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 28 Jun 2021 12:04:42 +0200
Subject: [PATCH 5/6] Update train.py

---
 train.py | 37 ++++++++-----------------------------
 1 file changed, 8 insertions(+), 29 deletions(-)

diff --git a/train.py b/train.py
index 58bbc527928d..01429c0e9b39 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,6 @@
 
 import argparse
 import logging
-import math
 import os
 import random
 import sys
@@ -16,6 +15,7 @@
 from pathlib import Path
 from threading import Thread
 
+import math
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
@@ -53,30 +53,6 @@
 WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 
-def get_step_condition(nw, accumulate):
-    "Supports gradient accumulation and warmup for `nw` iterations"
-    nw, accumulate = int(nw), int(accumulate)
-
-    def acc(ni):
-        "Ramp up number of accumulated grads from 1 to `accumulate`"
-        return int(1 + (accumulate - 1) / nw * ni)
-
-    warmup_steps, ni = [], nw - 1
-    while ni >= 0:
-        warmup_steps.append(ni)
-        ni -= acc(ni)
-    warmup_steps = set(warmup_steps)
-
-    def step_condition(ni):
-        "Return whether optimizer.step() is to be called in iteration `ni`"
-        nonlocal nw, accumulate, warmup_steps
-        i = ni - nw
-        if (i >= 0) and ((i + 1) % accumulate == 0): return True
-        return ni in warmup_steps
-
-    return step_condition
-
-
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
@@ -165,7 +141,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Optimizer
     nbs = 64  # nominal batch size
-    accumulate = max(int(round(nbs / batch_size)), 1)  # accumulate loss before optimizing
+    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
 
@@ -292,8 +268,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Start training
     t0 = time.time()
-    nw = round(hyp['warmup_epochs'] * nb)  # number of warmup iterations
-    do_step = get_step_condition(nw, accumulate)
+    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
+    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
+    last_opt_step = -1
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
@@ -340,6 +317,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if ni <= nw:
                 xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
+                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                     x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
@@ -367,12 +345,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             scaler.scale(loss).backward()
 
             # Optimize
-            if do_step(ni):
+            if ni - last_opt_step >= accumulate:
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
                 if ema:
                     ema.update(model)
+                last_opt_step = ni
 
             # Print
             if RANK in [-1, 0]:
From 9ac476ec0c81155037b094c3ad24721a2e024828 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 28 Jun 2021 12:15:17 +0200
Subject: [PATCH 6/6] revert math import move

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 01429c0e9b39..257be065f641 100644
--- a/train.py
+++ b/train.py
@@ -6,6 +6,7 @@
 
 import argparse
 import logging
+import math
 import os
 import random
 import sys
@@ -15,7 +16,6 @@
 from pathlib import Path
 from threading import Thread
 
-import math
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
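[Editor's note: not part of the patches. Distilled from the state of train.py after PATCH 5/6 and 6/6, this outline shows the accumulation logic the series converges on: the interp-based `accumulate` ramp is kept, and the optimizer step fires once enough gradients have piled up since the last step. Names follow train.py; the training loop, model, scaler, etc. are omitted and the numeric values are placeholders.]

```python
# Simplified outline of the final logic (illustration only; values are placeholders)
import numpy as np

nbs = 64                                      # nominal batch size
batch_size = 16                               # placeholder value
accumulate = max(round(nbs / batch_size), 1)  # batches to accumulate per optimizer step
nw = 1000                                     # warmup iterations (>= 1000 floor restored in PATCH 5/6)
last_opt_step = -1

for ni in range(3000):                        # ni = i + nb * epoch in train.py
    if ni <= nw:
        # ramp accumulate from 1 towards nbs / batch_size during warmup
        accumulate = max(1, np.interp(ni, [0, nw], [1, nbs / batch_size]).round())
    # ... forward pass and scaler.scale(loss).backward() happen every iteration ...
    if ni - last_opt_step >= accumulate:      # step only once enough gradients are accumulated
        # ... scaler.step(optimizer); scaler.update(); optimizer.zero_grad() in train.py ...
        last_opt_step = ni
```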