Refactor optimizer initialization (ultralytics#8607)
* Refactor optimizer initialization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Update train.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and Clay Januhowski committed Sep 8, 2022
1 parent 27bcc61 commit 777c146
Showing 2 changed files with 35 additions and 26 deletions.
train.py: 29 changes (4 additions, 25 deletions)

@@ -28,7 +28,7 @@
 import torch.nn as nn
 import yaml
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD, Adam, AdamW, lr_scheduler
+from torch.optim import lr_scheduler
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
@@ -54,7 +54,8 @@
 from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+                               torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
@@ -149,29 +150,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
-
-    g = [], [], []  # optimizer parameter groups
-    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
-    for v in model.modules():
-        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
-            g[2].append(v.bias)
-        if isinstance(v, bn):  # weight (no decay)
-            g[1].append(v.weight)
-        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
-            g[0].append(v.weight)
-
-    if opt.optimizer == 'Adam':
-        optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
-    elif opt.optimizer == 'AdamW':
-        optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999), weight_decay=0.0)
-    else:
-        optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
-
-    optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']})  # add g0 with weight_decay
-    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
-                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
-    del g
+    optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
 
     # Scheduler
     if opt.cos_lr:
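
Net effect in train.py: roughly 23 lines of inline parameter-group construction collapse into one smart_optimizer call. A minimal sketch of calling it directly, outside train.py (the ResNet stand-in model is illustrative; the hyperparameter values shown are what YOLOv5's stock hyp.scratch-low.yaml uses, but in train.py they arrive via opt.optimizer and the hyp dict):

import torchvision

from utils.torch_utils import smart_optimizer  # the helper added in this commit

# Any nn.Module works; YOLOv5 passes its detection model here (ResNet is a stand-in).
model = torchvision.models.resnet18()

# Mirrors the refactored train.py call site with stock hyp values (illustrative).
optimizer = smart_optimizer(model, name='SGD', lr=0.01, momentum=0.937, weight_decay=0.0005)
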
utils/torch_utils.py: 32 changes (31 additions, 1 deletion)

@@ -18,7 +18,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from utils.general import LOGGER, file_date, git_describe
+from utils.general import LOGGER, colorstr, file_date, git_describe
 
 try:
     import thop  # for FLOPs computation
@@ -260,6 +260,36 @@ def copy_attr(a, b, include=(), exclude=()):
             setattr(a, k, v)
 
 
+def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-5):
+    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
+    g = [], [], []  # optimizer parameter groups
+    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+    for v in model.modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
+            g[2].append(v.bias)
+        if isinstance(v, bn):  # weight (no decay)
+            g[1].append(v.weight)
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
+            g[0].append(v.weight)
+
+    if name == 'Adam':
+        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
+    elif name == 'AdamW':
+        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+    elif name == 'RMSProp':
+        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+    elif name == 'SGD':
+        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+    else:
+        raise NotImplementedError(f'Optimizer {name} not implemented.')
+
+    optimizer.add_param_group({'params': g[0], 'weight_decay': weight_decay})  # add g0 with weight_decay
+    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
+    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
+                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
+    return optimizer
+
+
 class EarlyStopping:
     # YOLOv5 simple early stopper
     def __init__(self, patience=30):
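
Two details of the new function are easy to miss. First, the bn tuple is built by scanning torch.nn for every class whose name contains 'Norm', so the no-decay rule covers BatchNorm*, LayerNorm, GroupNorm, and similar layers without a hard-coded list. Second, the biases in g[2] are handed to the optimizer constructor, which leaves them at the constructor's default weight_decay of 0. A standalone sanity check of the grouping (a sketch; the toy model is hypothetical and never run, only its parameters are inspected):

import torch.nn as nn

from utils.torch_utils import smart_optimizer

# Toy model: conv and linear weights should decay; the BatchNorm weight and all biases should not.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 2))

optimizer = smart_optimizer(model, name='SGD', lr=0.01, momentum=0.937, weight_decay=0.0005)

# Expected output: group 0 = 3 biases (weight_decay 0, the SGD default), group 1 = 2 weights
# (weight_decay 0.0005), group 2 = 1 BatchNorm weight (weight_decay 0.0).
for i, pg in enumerate(optimizer.param_groups):
    print(i, len(pg['params']), pg['weight_decay'])
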
