Update train.py #4136

Merged
merged 5 commits on Jul 24, 2021
104 changes: 47 additions & 57 deletions train.py
@@ -17,15 +17,13 @@

import math
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam, SGD, lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

@@ -58,16 +56,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
device,
):
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, = \
opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
opt.resume, opt.noval, opt.nosave, opt.workers

# Directories
save_dir = Path(save_dir)
wdir = save_dir / 'weights'
wdir.mkdir(parents=True, exist_ok=True) # make dir
last = wdir / 'last.pt'
best = wdir / 'best.pt'
results_file = save_dir / 'results.txt'
w = save_dir / 'weights' # weights dir
w.mkdir(parents=True, exist_ok=True) # make dir
last, best, results_file = w / 'last.pt', w / 'best.pt', save_dir / 'results.txt'

# Hyperparameters
if isinstance(hyp, str):
@@ -92,7 +87,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
loggers = {'wandb': None, 'tb': None} # loggers dict
if RANK in [-1, 0]:
# TensorBoard
if not evolve:
if plots:
prefix = colorstr('tensorboard: ')
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
loggers['tb'] = SummaryWriter(str(save_dir))
@@ -105,11 +100,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
loggers['wandb'] = wandb_logger.wandb
if loggers['wandb']:
data_dict = wandb_logger.data_dict
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update values if resuming

nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data) # check
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset

# Model
@@ -120,23 +115,22 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
state_dict = ckpt['model'].float().state_dict() # to FP32
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(state_dict, strict=False) # load
LOGGER.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(csd, strict=False) # load
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
else:
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(RANK):
check_dataset(data_dict) # check
train_path = data_dict['train']
val_path = data_dict['val']
train_path, val_path = data_dict['train'], data_dict['val']
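Note: the pretrained branch above transfers only checkpoint entries that still fit the freshly built model, dropping anchors when a new cfg or hyp overrides them. A minimal sketch of that intersect-and-load pattern (the helper name and checkpoint path below are illustrative, not the repo's own `intersect_dicts`):

```python
import torch

def intersect_state_dicts(da, db, exclude=()):
    # keep entries whose key exists in both dicts, is not excluded, and has a matching shape
    return {k: v for k, v in da.items()
            if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}

# usage sketch (assumes `model` was already built and 'yolov5s.pt' is a hypothetical checkpoint):
# ckpt = torch.load('yolov5s.pt', map_location='cpu')
# csd = ckpt['model'].float().state_dict()                                  # checkpoint state_dict as FP32
# csd = intersect_state_dicts(csd, model.state_dict(), exclude=['anchor'])  # keep only compatible entries
# model.load_state_dict(csd, strict=False)                                  # load, ignore the rest
```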

# Freeze
freeze = [] # parameter names to freeze (full or partial)
for k, v in model.named_parameters():
v.requires_grad = True # train all layers
if any(x in k for x in freeze):
print('freezing %s' % k)
print(f'freezing {k}')
v.requires_grad = False

# Optimizer
@@ -145,33 +139,32 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")

pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
for k, v in model.named_modules():
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
pg2.append(v.bias) # biases
if isinstance(v, nn.BatchNorm2d):
pg0.append(v.weight) # no decay
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
pg1.append(v.weight) # apply decay
g0, g1, g2 = [], [], [] # optimizer parameter groups
for v in model.modules():
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
g2.append(v.bias)
if isinstance(v, nn.BatchNorm2d): # weight with decay
g0.append(v.weight)
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight without decay
g1.append(v.weight)

if opt.adam:
optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
else:
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
LOGGER.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
del pg0, pg1, pg2
optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
optimizer.add_param_group({'params': g2}) # add g2 (biases)
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
del g0, g1, g2
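The grouping above exists so that weight decay hits only convolution/linear weights: BatchNorm weights and all biases go into groups without decay. A stand-alone sketch of the same split, assuming a model that contains BatchNorm layers (hyperparameter values are illustrative):

```python
import torch.nn as nn
from torch.optim import SGD

def build_optimizer(model, lr=0.01, momentum=0.937, weight_decay=0.0005):
    g_bn, g_w, g_b = [], [], []  # BatchNorm weights, other weights, biases
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            g_b.append(v.bias)             # biases: never decayed
        if isinstance(v, nn.BatchNorm2d):
            g_bn.append(v.weight)          # BatchNorm weights: not decayed
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            g_w.append(v.weight)           # conv/linear weights: decayed
    optimizer = SGD(g_bn, lr=lr, momentum=momentum, nesterov=True)            # base group, decay defaults to 0
    optimizer.add_param_group({'params': g_w, 'weight_decay': weight_decay})  # decayed weights
    optimizer.add_param_group({'params': g_b})                                # biases
    return optimizer

# opt = build_optimizer(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
```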

# Scheduler https://arxiv.org/pdf/1812.01187.pdf
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
# Scheduler
if opt.linear_lr:
lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
else:
lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# plot_lr_scheduler(optimizer, scheduler, epochs)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
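Both lambdas scale the base lr0 down to lr0 * lrf by the final epoch; only the ramp shape differs (straight line vs. half-cosine). A small runnable comparison, with illustrative hyperparameter values rather than the repo defaults:

```python
import math
import torch.nn as nn
from torch.optim import SGD, lr_scheduler

def one_cycle(y1=0.0, y2=1.0, steps=100):
    # sinusoidal ramp from y1 to y2, same helper as utils/general.py
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

epochs, lr0, lrf = 300, 0.01, 0.2                                # illustrative values
linear = lambda x: (1 - x / (epochs - 1)) * (1.0 - lrf) + lrf    # linear: 1.0 -> lrf
cosine = one_cycle(1, lrf, epochs)                               # cosine: 1.0 -> lrf

model = nn.Linear(10, 1)                                         # dummy model to hold parameters
optimizer = SGD(model.parameters(), lr=lr0)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine)   # lr at epoch e = lr0 * lr_lambda(e)
```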

# EMA
ema = ModelEMA(model) if RANK in [-1, 0] else None
@@ -196,13 +189,12 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Epochs
start_epoch = ckpt['epoch'] + 1
if resume:
assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
if epochs < start_epoch:
LOGGER.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
(weights, ckpt['epoch'], epochs))
LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
epochs += ckpt['epoch'] # finetune additional epochs

del ckpt, state_dict
del ckpt, csd

# Image sizes
gs = max(int(model.stride.max()), 32) # grid size (max stride)
@@ -217,7 +209,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# SyncBatchNorm
if opt.sync_bn and cuda and RANK != -1:
raise Exception('can not train with --sync-bn, known issue https://github.com/ultralytics/yolov5/issues/3998')
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
LOGGER.info('Using SyncBatchNorm()')

@@ -228,7 +219,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
prefix=colorstr('train: '))
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(train_loader) # number of batches
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

# Process 0
if RANK in [-1, 0]:
@@ -261,7 +252,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
hyp['label_smoothing'] = opt.label_smoothing
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
model.names = names

@@ -315,7 +305,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Warmup
if ni <= nw:
xi = [0, nw] # x interp
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
# compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
for j, x in enumerate(optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
@@ -329,7 +319,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
sf = sz / max(imgs.shape[2:]) # scale factor
if sf != 1:
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
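The multi-scale branch above picks a random training size between 0.5x and 1.5x of imgsz each batch, snapped to a multiple of the grid stride gs so the feature-map shapes stay valid. A self-contained sketch with illustrative sizes:

```python
import math
import random
import torch
import torch.nn as nn

imgsz, gs = 640, 32                                    # nominal train size and max stride (illustrative)
imgs = torch.rand(4, 3, imgsz, imgsz)                  # dummy batch

sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + gs)) // gs * gs   # random size, gs-multiple
sf = sz / max(imgs.shape[2:])                                               # scale factor
if sf != 1:
    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]              # new shape, gs-multiple
    imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
```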

# Forward
with amp.autocast(enabled=cuda):
@@ -355,7 +345,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Print
if RANK in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
pbar.set_description(s)
@@ -381,7 +371,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# DDP process 0 or single-GPU
if RANK in [-1, 0]:
# mAP
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
if not noval or final_epoch: # Calculate mAP
wandb_logger.current_epoch = epoch + 1
@@ -457,6 +447,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
batch_size=batch_size // WORLD_SIZE * 2,
imgsz=imgsz,
model=attempt_load(m, device).half(),
iou_thres=0.7, # NMS IoU threshold for best pycocotools results
single_cls=single_cls,
dataloader=val_loader,
save_dir=save_dir,
@@ -525,16 +516,14 @@ def main(opt):
check_requirements(exclude=['thop'])

# Resume
wandb_run = check_wandb_resume(opt)
if opt.resume and not wandb_run: # resume an interrupted run
if opt.resume and not check_wandb_resume(opt): # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
opt = argparse.Namespace(**yaml.safe_load(f)) # replace
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
LOGGER.info(f'Resuming training from {ckpt}')
else:
# opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
opt.name = 'evolve' if opt.evolve else opt.name
@@ -545,11 +534,13 @@ def main(opt):
if LOCAL_RANK != -1:
from datetime import timedelta
assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
assert not opt.evolve, '--evolve argument is not compatible with DDP training'
assert not opt.sync_bn, '--sync-bn known training issue, see https://github.com/ultralytics/yolov5/issues/3998'
torch.cuda.set_device(LOCAL_RANK)
device = torch.device('cuda', LOCAL_RANK)
dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=60))
assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
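For context, a minimal sketch of the process-group setup this DDP block performs; the rank/world-size variables come from the environment populated by the PyTorch distributed launcher (the launch command in the comment is an assumption, not taken from this PR):

```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))   # set by the launcher; -1 means single-process training
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

if LOCAL_RANK != -1:
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo',
                            timeout=timedelta(seconds=60))
# e.g. launched with: python -m torch.distributed.launch --nproc_per_node 2 train.py ...  (assumed invocation)
```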

# Train
if not opt.evolve:
@@ -594,7 +585,6 @@ def main(opt):
hyp = yaml.safe_load(f) # load hyps dict
if 'anchors' not in hyp: # anchors commented in hyp.yaml
hyp['anchors'] = 3
assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
opt.noval, opt.nosave = True, True # only val/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
@@ -646,7 +636,7 @@ def main(opt):


def run(**kwargs):
# Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
# Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
opt = parse_opt(True)
for k, v in kwargs.items():
setattr(opt, k, v)
2 changes: 1 addition & 1 deletion utils/general.py
@@ -301,7 +301,7 @@ def clean_str(s):


def one_cycle(y1=0.0, y2=1.0, steps=100):
# lambda function for sinusoidal ramp from y1 to y2
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
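A quick sanity check of the ramp endpoints, assuming the function exactly as shown:

```python
import math

def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

lf = one_cycle(1.0, 0.2, steps=300)
assert abs(lf(0) - 1.0) < 1e-9     # start of schedule: y1
assert abs(lf(150) - 0.6) < 1e-9   # halfway: midpoint of y1 and y2
assert abs(lf(300) - 0.2) < 1e-9   # end of schedule: y2
```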


2 changes: 1 addition & 1 deletion utils/loss.py
@@ -108,7 +108,7 @@ def __init__(self, model, autobalance=False):
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
for k in 'na', 'nc', 'nl', 'anchors':
setattr(self, k, getattr(det, k))
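Here gr is the objectness/IoU mixing ratio that train.py previously set via model.gr; pinning it to 1.0 makes the objectness target equal to the predicted box's IoU with its matched ground truth. A hedged sketch of how such a ratio is typically applied when building the objectness target (illustrative, not a verbatim excerpt of ComputeLoss):

```python
import torch

def objectness_target(iou, gr=1.0):
    # blend a constant target of 1.0 with the box IoU; gr=1.0 means "use the IoU alone"
    return (1.0 - gr) + gr * iou.clamp(0)

iou = torch.tensor([0.3, 0.7, 0.9])
print(objectness_target(iou, gr=1.0))   # tensor([0.3000, 0.7000, 0.9000])
print(objectness_target(iou, gr=0.5))   # tensor([0.6500, 0.8500, 0.9500])
```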
