From 777c1467237fa6efcd5bb3f148e34605f1c85ffd Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 17 Jul 2022 11:43:52 +0200
Subject: [PATCH] Refactor optimizer initialization (#8607)

* Refactor optimizer initialization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Update train.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 train.py             | 29 ++++-------------------------
 utils/torch_utils.py | 32 +++++++++++++++++++++++++++++++-
 2 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/train.py b/train.py
index ff13f1e256ec..6b463bf56423 100644
--- a/train.py
+++ b/train.py
@@ -28,7 +28,7 @@
 import torch.nn as nn
 import yaml
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD, Adam, AdamW, lr_scheduler
+from torch.optim import lr_scheduler
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
@@ -54,7 +54,8 @@
 from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+                               torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
@@ -149,29 +150,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
-
-    g = [], [], []  # optimizer parameter groups
-    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
-    for v in model.modules():
-        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
-            g[2].append(v.bias)
-        if isinstance(v, bn):  # weight (no decay)
-            g[1].append(v.weight)
-        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
-            g[0].append(v.weight)
-
-    if opt.optimizer == 'Adam':
-        optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
-    elif opt.optimizer == 'AdamW':
-        optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999), weight_decay=0.0)
-    else:
-        optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
-
-    optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']})  # add g0 with weight_decay
-    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
-                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
-    del g
+    optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
 
     # Scheduler
     if opt.cos_lr:
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index c21dc6658c1e..d82368dc6271 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -18,7 +18,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from utils.general import LOGGER, file_date, git_describe
+from utils.general import LOGGER, colorstr, file_date, git_describe
 
 try:
     import thop  # for FLOPs computation
@@ -260,6 +260,36 @@ def copy_attr(a, b, include=(), exclude=()):
             setattr(a, k, v)
 
 
+def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-5):
+    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
+    g = [], [], []  # optimizer parameter groups
+    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+    for v in model.modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
+            g[2].append(v.bias)
+        if isinstance(v, bn):  # weight (no decay)
+            g[1].append(v.weight)
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
+            g[0].append(v.weight)
+
+    if name == 'Adam':
+        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
+    elif name == 'AdamW':
+        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+    elif name == 'RMSProp':
+        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+    elif name == 'SGD':
+        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+    else:
+        raise NotImplementedError(f'Optimizer {name} not implemented.')
+
+    optimizer.add_param_group({'params': g[0], 'weight_decay': weight_decay})  # add g0 with weight_decay
+    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
+    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
+                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
+    return optimizer
+
+
 class EarlyStopping:
     # YOLOv5 simple early stopper
     def __init__(self, patience=30):
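
Below is a minimal usage sketch of the new `smart_optimizer()` helper (illustrative only, not part of the patch). The toy model and hyperparameter values are arbitrary, and it assumes the YOLOv5 repository root is on PYTHONPATH with the project requirements installed:

```python
# Illustrative sketch: exercise smart_optimizer() outside train.py.
# Assumes the YOLOv5 repo root is importable and its requirements are installed.
import torch.nn as nn

from utils.torch_utils import smart_optimizer

# Toy model containing all three parameter kinds: a conv weight (decay),
# a BatchNorm weight (no decay) and two biases (no decay).
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

# Hyperparameter values here are arbitrary, not taken from the PR.
optimizer = smart_optimizer(model, name='SGD', lr=0.01, momentum=0.937, weight_decay=5e-4)

# Inspect the resulting parameter groups and their weight_decay settings:
# group 0 -> biases (no decay), group 1 -> weights with decay, group 2 -> norm weights (no decay).
for i, group in enumerate(optimizer.param_groups):
    print(i, len(group['params']), group['weight_decay'])
```

Because the biases seed the optimizer and the other two lists are appended via `add_param_group()`, the returned `param_groups` come out in the order biases, decayed weights, normalization weights.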