smart_optimizer() improved reporting (ultralytics#8887)
Update smart_optimizer() weight_decay reporting
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 362d5df commit d2d732b
Showing 2 changed files with 4 additions and 5 deletions.
train.py (1 change: 0 additions & 1 deletion)
```diff
@@ -152,7 +152,6 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     nbs = 64  # nominal batch size
     accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
-    LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
     optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])

     # Scheduler
```
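For context, the scaling rule above keeps the effective decay strength constant under gradient accumulation: decay grows with the per-step batch size but is divided back down by the nominal batch. A minimal standalone sketch with made-up example values (not part of the diff):

```python
# Hypothetical illustration of the weight_decay scaling in train.py above.
nbs = 64             # nominal batch size the hyperparameters are tuned for
batch_size = 128     # example actual batch size
weight_decay = 5e-4  # example base hyperparameter

accumulate = max(round(nbs / batch_size), 1)  # 1: no accumulation needed at a large batch
scaled = weight_decay * batch_size * accumulate / nbs
print(scaled)  # 0.001: decay doubles when the effective batch is twice nominal
```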
utils/torch_utils.py (8 changes: 4 additions & 4 deletions)
```diff
@@ -276,7 +276,7 @@ def copy_attr(a, b, include=(), exclude=()):
         setattr(a, k, v)


-def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-5):
+def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
     # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
     g = [], [], []  # optimizer parameter groups
     bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
@@ -299,10 +299,10 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-5):
     else:
         raise NotImplementedError(f'Optimizer {name} not implemented.')

-    optimizer.add_param_group({'params': g[0], 'weight_decay': weight_decay})  # add g0 with weight_decay
+    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
     optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
-                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
+    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
+                f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
     return optimizer
```


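The grouping logic this diff touches splits parameters into three groups so that decay is applied only to conv/linear weights, never to normalization weights or biases. A minimal runnable sketch of that pattern, assuming plain SGD and a hypothetical helper name (build_param_groups is not YOLOv5 API; the real function also selects Adam, AdamW or RMSProp by name):

```python
# Hypothetical standalone sketch of the 3-param-group pattern, not the
# YOLOv5 function itself.
import torch
import torch.nn as nn

def build_param_groups(model, lr=0.01, momentum=0.9, decay=1e-5):
    g = [], [], []  # 0) decayed weights, 1) norm-layer weights, 2) biases
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # BatchNorm2d, LayerNorm, ...
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            g[2].append(v.bias)  # biases: never decayed
        if isinstance(v, bn) and isinstance(v.weight, nn.Parameter):
            g[1].append(v.weight)  # normalization weights: no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            g[0].append(v.weight)  # conv/linear weights: decayed

    opt = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    opt.add_param_group({'params': g[0], 'weight_decay': decay})
    opt.add_param_group({'params': g[1], 'weight_decay': 0.0})
    return opt

# With the patch applied, the startup log reports lr and per-group decay,
# along the lines of (hypothetical counts):
# optimizer: SGD(lr=0.01) with parameter groups 57 weight(decay=0.0), 60 weight(decay=1e-05), 60 bias
opt = build_param_groups(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
```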
