
Update train.py #6

Merged: 1 commit, Jul 14, 2020
train.py: 32 changes (16 additions & 16 deletions)
@@ -123,7 +123,7 @@ def train(hyp, tb_writer, opt, device):

# Load Model
# Avoid multiple downloads.
- with torch_distributed_zero_first(opt.local_rank):
+ with torch_distributed_zero_first(local_rank):
google_utils.attempt_download(weights)
start_epoch, best_fitness = 0, 0.0
if weights.endswith('.pt'): # pytorch format
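
Note: torch_distributed_zero_first used above is the usual "rank 0 goes first" barrier pattern: every other process blocks until rank 0 has finished the body (here, the weight download). A minimal sketch of such a helper, reconstructed from the call site rather than copied from this repo's utils:

from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def torch_distributed_zero_first(local_rank):
    # Everyone except rank 0 (and rank -1, the non-distributed case)
    # waits here, so only one process runs the body, e.g. the download.
    if local_rank not in [-1, 0]:
        dist.barrier()
    yield
    # Rank 0 arrives here after finishing the body; this barrier
    # releases the waiting ranks.
    if local_rank == 0:
        dist.barrier()
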
@@ -137,7 +137,7 @@ def train(hyp, tb_writer, opt, device):
except KeyError as e:
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
"Please delete or update %s and try again, or use --weights '' to train from scratch." \
- % (opt.weights, opt.cfg, opt.weights, opt.weights)
+ % (weights, opt.cfg, weights, weights)
raise KeyError(s) from e

# load optimizer
@@ -154,7 +154,7 @@ def train(hyp, tb_writer, opt, device):
start_epoch = ckpt['epoch'] + 1
if epochs < start_epoch:
print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
- (opt.weights, ckpt['epoch'], epochs))
+ (weights, ckpt['epoch'], epochs))
epochs += ckpt['epoch'] # finetune additional epochs

del ckpt
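
Note: the epoch bookkeeping in this hunk is easy to misread. When the requested --epochs is smaller than what the checkpoint has already completed, it is reinterpreted as additional fine-tuning epochs. With hypothetical numbers:

# Hypothetical values illustrating the resume arithmetic above.
ckpt_epoch = 299               # ckpt['epoch']: last epoch saved in weights
start_epoch = ckpt_epoch + 1   # resume from epoch 300
epochs = 100                   # --epochs passed on the command line

if epochs < start_epoch:
    # The print above fires here: "trained for 299 epochs.
    # Fine-tuning for 100 additional epochs."
    epochs += ckpt_epoch       # 100 -> 399, i.e. keep going past epoch 299

assert (start_epoch, epochs) == (300, 399)
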
@@ -170,30 +170,30 @@ def train(hyp, tb_writer, opt, device):
# plot_lr_scheduler(optimizer, scheduler, epochs)

# DP mode
- if device.type != 'cpu' and opt.local_rank == -1 and torch.cuda.device_count() > 1:
+ if device.type != 'cpu' and local_rank == -1 and torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)

# Exponential moving average
# From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
# "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
# chenyzsjtu: ema should be placed after SyncBN, as SyncBN introduces new modules.
- if device.type != 'cpu' and opt.local_rank != -1:
+ if device.type != 'cpu' and local_rank != -1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
- ema = torch_utils.ModelEMA(model) if opt.local_rank in [-1, 0] else None
+ ema = torch_utils.ModelEMA(model) if local_rank in [-1, 0] else None

# DDP mode
- if device.type != 'cpu' and opt.local_rank != -1:
+ if device.type != 'cpu' and local_rank != -1:
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
- hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, local_rank=opt.local_rank)
+ hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, local_rank=local_rank)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)

# Testloader
- if opt.local_rank in [-1, 0]:
+ if local_rank in [-1, 0]:
# local_rank is set to -1 because only the first process is expected to do evaluation.
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
hyp=hyp, augment=False, cache=opt.cache_images, rect=True, local_rank=-1)[0]
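
Note: the two comments above disagree on wording, but the order the code settles on is SyncBN first, then EMA, then DDP: convert_sync_batchnorm replaces every BatchNorm with a new module, so an EMA copy made earlier would track stale modules, and DDP should wrap the final module tree. A standalone sketch of that ordering, with a deliberately simplified EMA stand-in for torch_utils.ModelEMA (the real one also handles DDP's 'module.' key prefix):

import copy

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class SimpleEMA:
    """Toy stand-in for torch_utils.ModelEMA (illustration only)."""

    def __init__(self, model, decay=0.9999):
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.mul_(self.decay).add_(msd[k], alpha=1.0 - self.decay)


def wrap_model(model, local_rank, device):
    """Same wrapping order as the hunk above: SyncBN -> EMA -> DDP."""
    model = model.to(device)
    if device.type != 'cpu' and local_rank != -1:
        # 1) SyncBN first: it swaps in new BatchNorm modules.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # 2) EMA next, only where logging/saving happens (rank -1 or 0).
    ema = SimpleEMA(model) if local_rank in [-1, 0] else None
    if device.type != 'cpu' and local_rank != -1:
        # 3) DDP last, so gradient sync sees the final module tree.
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    return model, ema
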
@@ -226,7 +226,7 @@ def train(hyp, tb_writer, opt, device):
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
scheduler.last_epoch = start_epoch - 1 # do not move
- if opt.local_rank in [0, -1]:
+ if local_rank in [0, -1]:
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % dataloader.num_workers)
print('Starting training for %g epochs...' % epochs)
@@ -256,9 +256,9 @@ def train(hyp, tb_writer, opt, device):
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders

mloss = torch.zeros(4, device=device) # mean losses
- if opt.local_rank != -1:
+ if local_rank != -1:
dataloader.sampler.set_epoch(epoch)
- if opt.local_rank in [-1, 0]:
+ if local_rank in [-1, 0]:
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
else:
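
Note: dataloader.sampler.set_epoch(epoch) above is what makes DistributedSampler actually reshuffle: the sampler seeds its permutation with the epoch number, so omitting the call replays the epoch-0 order every epoch. A self-contained sketch (num_replicas and rank are passed explicitly here so it runs without an initialized process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset; in train.py this is the image/label dataset.
dataset = TensorDataset(torch.arange(16).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # re-seed the shuffle for this epoch; without
                              # this, every epoch yields the same order
    for (batch,) in loader:
        pass  # training step would go here
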
@@ -293,7 +293,7 @@ def train(hyp, tb_writer, opt, device):
# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model)
# loss is scaled by batch size in compute_loss, but in DDP mode gradients are averaged across devices.
- if opt.local_rank != -1:
+ if local_rank != -1:
loss *= dist.get_world_size()
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
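
Note: the world-size multiplication above compensates for DDP averaging, rather than summing, gradients across processes; since compute_loss already scales by the per-GPU batch size, pre-multiplying restores the single-process gradient magnitude. A hedged helper capturing the same arithmetic:

import torch
import torch.distributed as dist

def scale_loss_for_ddp(loss: torch.Tensor) -> torch.Tensor:
    """Undo DDP's mean gradient reduction by pre-multiplying the loss.

    DDP all-reduces gradients as a mean over N processes:
        grad = (1/N) * sum_i grad_i
    Multiplying each local loss by N beforehand cancels the 1/N:
        (1/N) * sum_i grad(N * loss_i) = sum_i grad(loss_i)
    """
    if dist.is_available() and dist.is_initialized():
        loss = loss * dist.get_world_size()
    return loss
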
@@ -314,7 +314,7 @@ def train(hyp, tb_writer, opt, device):
ema.update(model)

# Print
- if opt.local_rank in [-1, 0]:
+ if local_rank in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
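
Note: the mloss update above, (mloss * i + loss_items) / (i + 1), is the standard incremental mean; a quick check with made-up numbers:

import torch

losses = torch.tensor([[4.0], [2.0], [3.0]])  # hypothetical per-batch loss_items
mloss = torch.zeros(1)
for i, loss_items in enumerate(losses):
    mloss = (mloss * i + loss_items) / (i + 1)  # same update as train.py
assert torch.allclose(mloss, losses.mean(0))    # equals the plain mean, 3.0
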
@@ -335,7 +335,7 @@ def train(hyp, tb_writer, opt, device):
scheduler.step()

# Only the first process in DDP mode is allowed to log or save checkpoints.
- if opt.local_rank in [-1, 0]:
+ if local_rank in [-1, 0]:
# mAP
if ema is not None:
ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
@@ -387,7 +387,7 @@ def train(hyp, tb_writer, opt, device):
# end epoch ----------------------------------------------------------------------------------------------------
# end training

- if opt.local_rank in [-1, 0]:
+ if local_rank in [-1, 0]:
# Strip optimizers
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
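
Note: every change in this diff is the same mechanical substitution, opt.local_rank becoming a plain local_rank variable. The diff does not show where that variable is defined; a common setup under torch.distributed.launch, offered as an assumption about the surrounding code rather than this fork's exact source:

import argparse

import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
# torch.distributed.launch injects --local_rank into each worker process.
parser.add_argument('--local_rank', type=int, default=-1)
opt = parser.parse_args()

local_rank = opt.local_rank  # read once, then pass the plain int around
if local_rank != -1:
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    device = torch.device('cuda', local_rank)
else:
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
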