diff --git a/train.py b/train.py
index 27a877157302..0c63b80ae27b 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,13 @@
 import argparse
 
+import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
@@ -59,11 +61,20 @@
     print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
 
 
-def train(hyp):
+def train(hyp, tb_writer, opt, device):
     epochs = opt.epochs  # 300
-    batch_size = opt.batch_size  # 64
+    batch_size = opt.batch_size  # batch size per process
+    total_batch_size = opt.batch_size if opt.local_rank == -1 else opt.batch_size * torch.distributed.get_world_size()  # total batch size across all processes
     weights = opt.weights  # initial training weights
 
+    if opt.local_rank in [-1, 0]:
+        # TODO: Init DDP logging. Only the first process is allowed to log.
+        # Since there are many direct print statements here, DDP logging is skipped for now.
+        pass
+    else:
+        tb_writer = None
+
+
     # Configure
     init_seeds(1)
     with open(opt.data) as f:
@@ -87,8 +98,15 @@ def train(hyp):
 
     # Optimizer
     nbs = 64  # nominal batch size
-    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
-    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
+    # The default DDP implementation is slow for accumulation, see https://pytorch.org/docs/stable/notes/ddp.html:
+    # the all-reduce operation is carried out during loss.backward(),
+    # so there are redundant all-reduce communications within an accumulation procedure,
+    # which means the result is still correct but training gets slower.
+    # TODO: If acceleration is needed, there is an implementation of allreduce_post_accumulation
+    # in https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/run_pretraining.py
+    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
+    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
+
     pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
     for k, v in model.named_parameters():
         if v.requires_grad:
@@ -107,7 +125,8 @@ def train(hyp):
     del pg0, pg1, pg2
 
     # Load Model
-    google_utils.attempt_download(weights)
+    if opt.local_rank in [-1, 0]:
+        google_utils.attempt_download(weights)
     start_epoch, best_fitness = 0, 0.0
     if weights.endswith('.pt'):  # pytorch format
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
@@ -153,24 +172,32 @@ def train(hyp):
     # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
     # plot_lr_scheduler(optimizer, scheduler, epochs)
 
-    # Initialize distributed training
-    if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
-        dist.init_process_group(backend='nccl',  # distributed backend
-                                init_method='tcp://127.0.0.1:9999',  # init method
-                                world_size=1,  # number of nodes
-                                rank=0)  # node rank
-        model = torch.nn.parallel.DistributedDataParallel(model)
-        # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
-
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
-                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, local_rank=opt.local_rank)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
 
     # Testloader
-    testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
-                                   hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
+    if opt.local_rank in [-1, 0]:
+        # local_rank is passed as -1 here because only the first process is expected to run evaluation.
+        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
+                                       hyp=hyp, augment=False, cache=opt.cache_images, rect=True, local_rank=-1)[0]
+
+    # DP mode
+    if device.type != 'cpu' and opt.local_rank == -1 and torch.cuda.device_count() > 1:
+        model = torch.nn.DataParallel(model)
+
+    # Exponential moving average
+    # According to https://github.com/rwightman/pytorch-image-models/blob/master/train.py,
+    # it is important to create the EMA model after cuda(), DP wrapper and AMP, but before SyncBN and the DDP wrapper.
+    if opt.local_rank in [-1, 0]:
+        ema = torch_utils.ModelEMA(model)
+
+    # DDP mode
+    if device.type != 'cpu' and opt.local_rank != -1:
+        # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
+        model = DDP(model, device_ids=[opt.local_rank])
 
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
@@ -180,30 +207,29 @@ def train(hyp):
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
 
     # Class frequency
-    labels = np.concatenate(dataset.labels, 0)
-    c = torch.tensor(labels[:, 0])  # classes
-    # cf = torch.bincount(c.long(), minlength=nc) + 1.
-    # model._initialize_biases(cf.to(device))
     if tb_writer:
+        labels = np.concatenate(dataset.labels, 0)
+        c = torch.tensor(labels[:, 0])  # classes
+        # cf = torch.bincount(c.long(), minlength=nc) + 1.
+        # model._initialize_biases(cf.to(device))
         plot_labels(labels)
         tb_writer.add_histogram('classes', c, 0)
 
+    # Check anchors
     if not opt.noautoanchor:
         check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
-    # Exponential moving average
-    ema = torch_utils.ModelEMA(model)
-
     # Start training
     t0 = time.time()
     nb = len(dataloader)  # number of batches
     n_burn = max(3 * nb, 1e3)  # burn-in iterations, max(3 epochs, 1k iterations)
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
-    print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
-    print('Using %g dataloader workers' % dataloader.num_workers)
-    print('Starting training for %g epochs...' % epochs)
+    if opt.local_rank in [0, -1]:
+        print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
+        print('Using %g dataloader workers' % dataloader.num_workers)
+        print('Starting training for %g epochs...' % epochs)
     # torch.autograd.set_detect_anomaly(True)
     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
         model.train()
@@ -219,8 +245,11 @@ def train(hyp):
             # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
 
         mloss = torch.zeros(4, device=device)  # mean losses
-        print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
-        pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
+        if opt.local_rank in [-1, 0]:
+            print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
+            pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
+        else:
+            pbar = enumerate(dataloader)
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             ni = i + nb * epoch  # number integrated batches (since train start)
             imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
@@ -229,7 +258,7 @@ def train(hyp):
             if ni <= n_burn:
                 xi = [0, n_burn]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
-                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
+                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                     x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
@@ -264,93 +293,98 @@ def train(hyp):
             if ni % accumulate == 0:
                 optimizer.step()
                 optimizer.zero_grad()
-                ema.update(model)
+                if opt.local_rank in [-1, 0]:
+                    ema.update(model)
 
             # Print
             mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
-            mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
-            s = ('%10s' * 2 + '%10.4g' * 6) % (
-                '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
-            pbar.set_description(s)
-
-            # Plot
-            if ni < 3:
-                f = 'train_batch%g.jpg' % ni  # filename
-                result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
-                if tb_writer and result is not None:
-                    tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
-                    # tb_writer.add_graph(model, imgs)  # add model to tensorboard
+            if opt.local_rank in [-1, 0]:
+                mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+                s = ('%10s' * 2 + '%10.4g' * 6) % (
+                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
+                pbar.set_description(s)
+                # Plot
+                if ni < 3:
+                    f = 'train_batch%g.jpg' % ni  # filename
+                    result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
+                    if tb_writer and result is not None:
+                        tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
+                        # tb_writer.add_graph(model, imgs)  # add model to tensorboard
 
             # end batch ------------------------------------------------------------------------------------------------
 
         # Scheduler
         scheduler.step()
 
-        # mAP
-        ema.update_attr(model)
-        final_epoch = epoch + 1 == epochs
-        if not opt.notest or final_epoch:  # Calculate mAP
-            results, maps, times = test.test(opt.data,
-                                             batch_size=batch_size,
-                                             imgsz=imgsz_test,
-                                             save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
-                                             model=ema.ema,
-                                             single_cls=opt.single_cls,
-                                             dataloader=testloader)
-
-        # Write
-        with open(results_file, 'a') as f:
-            f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
-        if len(opt.name) and opt.bucket:
-            os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))
-
-        # Tensorboard
-        if tb_writer:
-            tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
-                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
-                    'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
-            for x, tag in zip(list(mloss[:-1]) + list(results), tags):
-                tb_writer.add_scalar(tag, x, epoch)
-
-        # Update best mAP
-        fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
-        if fi > best_fitness:
-            best_fitness = fi
-
-        # Save model
-        save = (not opt.nosave) or (final_epoch and not opt.evolve)
-        if save:
-            with open(results_file, 'r') as f:  # create checkpoint
-                ckpt = {'epoch': epoch,
-                        'best_fitness': best_fitness,
-                        'training_results': f.read(),
-                        'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict()}
-
-            # Save last, best and delete
-            torch.save(ckpt, last)
-            if (best_fitness == fi) and not final_epoch:
-                torch.save(ckpt, best)
-            del ckpt
+        # Only the first process in DDP mode is allowed to log or save checkpoints.
+        if opt.local_rank in [-1, 0]:
+            # mAP
+            ema.update_attr(model)
+            final_epoch = epoch + 1 == epochs
+            if not opt.notest or final_epoch:  # Calculate mAP
+                results, maps, times = test.test(opt.data,
+                                                 batch_size=total_batch_size,
+                                                 imgsz=imgsz_test,
+                                                 save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
+                                                 model=ema.ema,
+                                                 single_cls=opt.single_cls,
+                                                 dataloader=testloader)
+
+            # Write
+            with open(results_file, 'a') as f:
+                f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
+            if len(opt.name) and opt.bucket:
+                os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))
+
+            # Tensorboard
+            if tb_writer:
+                tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
+                        'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
+                        'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
+                for x, tag in zip(list(mloss[:-1]) + list(results), tags):
+                    tb_writer.add_scalar(tag, x, epoch)
+
+            # Update best mAP
+            fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
+            if fi > best_fitness:
+                best_fitness = fi
+
+            # Save model
+            save = (not opt.nosave) or (final_epoch and not opt.evolve)
+            if save:
+                with open(results_file, 'r') as f:  # create checkpoint
+                    ckpt = {'epoch': epoch,
+                            'best_fitness': best_fitness,
+                            'training_results': f.read(),
+                            'model': ema.ema.module if hasattr(ema, 'module') else ema.ema,
+                            'optimizer': None if final_epoch else optimizer.state_dict()}
+
+                # Save last, best and delete
+                torch.save(ckpt, last)
+                if (best_fitness == fi) and not final_epoch:
+                    torch.save(ckpt, best)
+                del ckpt
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
 
-    n = opt.name
-    if len(n):
-        n = '_' + n if not n.isnumeric() else n
-        fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
-        for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
-            if os.path.exists(f1):
-                os.rename(f1, f2)  # rename
-                ispt = f2.endswith('.pt')  # is *.pt
-                strip_optimizer(f2) if ispt else None  # strip optimizer
-                os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload
-
-    if not opt.evolve:
-        plot_results()  # save as results.png
-    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
-    dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
+    if opt.local_rank in [-1, 0]:
+        n = opt.name
+        if len(n):
+            n = '_' + n if not n.isnumeric() else n
+            fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
+            for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
+                if os.path.exists(f1):
+                    os.rename(f1, f2)  # rename
+                    ispt = f2.endswith('.pt')  # is *.pt
+                    strip_optimizer(f2) if ispt else None  # strip optimizer
+                    os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload
+        if not opt.evolve:
+            plot_results()  # save as results.png
+        print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
+    else:
+        results = None  # only the first process computes results
+    if opt.local_rank != -1:
+        dist.destroy_process_group()
     torch.cuda.empty_cache()
     return results
@@ -359,7 +393,7 @@ def train(hyp):
     check_git_status()
     parser = argparse.ArgumentParser()
     parser.add_argument('--epochs', type=int, default=300)
-    parser.add_argument('--batch-size', type=int, default=16)
+    parser.add_argument('--batch-size', type=int, default=16, help="Total batch size for all GPUs.")
     parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='*.cfg path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
@@ -377,24 +411,40 @@ def train(hyp):
     parser.add_argument('--adam', action='store_true', help='use adam optimizer')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
+    # Parameter for DDP.
+    parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for the DDP implementation. Do not set it manually.")
     opt = parser.parse_args()
     opt.weights = last if opt.resume else opt.weights
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
     print(opt)
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
+    # If local_rank is not -1, DDP mode is triggered; local_rank then overrides the opt.device setting.
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
     if device.type == 'cpu':
         mixed_precision = False
+    elif opt.local_rank != -1:
+        assert torch.cuda.device_count() > opt.local_rank
+        torch.cuda.set_device(opt.local_rank)
+        device = torch.device("cuda")
+        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
+
+        assert opt.batch_size % torch.distributed.get_world_size() == 0, '--batch-size must be a multiple of the number of GPUs'
+        opt.batch_size = opt.batch_size // torch.distributed.get_world_size()  # batch size per process
 
     # Train
     if not opt.evolve:
-        tb_writer = SummaryWriter(comment=opt.name)
+        if opt.local_rank in [-1, 0]:
+            tb_writer = SummaryWriter(comment=opt.name)
+        else:
+            tb_writer = None
         print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
-        train(hyp)
+        train(hyp, tb_writer, opt, device)
 
     # Evolve hyperparameters (optional)
     else:
+        assert opt.local_rank == -1, "DDP mode is currently not implemented for --evolve!"
+        tb_writer = None
         opt.notest, opt.nosave = True, True  # only test/save final epoch
 
         if opt.bucket:
@@ -433,7 +483,7 @@ def train(hyp):
                 hyp[k] = np.clip(hyp[k], v[0], v[1])
 
             # Train mutation
-            results = train(hyp.copy())
+            results = train(hyp.copy(), tb_writer, opt, device)
 
             # Write mutation results
             print_mutation(hyp, results, opt.bucket)
diff --git a/utils/datasets.py b/utils/datasets.py
index 1ebd709482fe..b01201f46dd3 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -14,7 +14,7 @@
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
-from utils.utils import xyxy2xywh, xywh2xyxy
+from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
 
 help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
 img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
@@ -41,21 +41,25 @@ def exif_size(img):
     return s
 
 
-def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
-    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
-                                  augment=augment,  # augment images
-                                  hyp=hyp,  # augmentation hyperparameters
-                                  rect=rect,  # rectangular training
-                                  cache_images=cache,
-                                  single_cls=opt.single_cls,
-                                  stride=stride,
-                                  pad=pad)
+def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False, local_rank=-1):
+    # Make sure only the first process in DDP processes the dataset first, so the other processes can use the cache.
+    with torch_distributed_zero_first(local_rank):
+        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
+                                      augment=augment,  # augment images
+                                      hyp=hyp,  # augmentation hyperparameters
+                                      rect=rect,  # rectangular training
+                                      cache_images=cache,
+                                      single_cls=opt.single_cls,
+                                      stride=stride,
+                                      pad=pad)
 
     batch_size = min(batch_size, len(dataset))
     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
+    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None
     dataloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              num_workers=nw,
+                                             sampler=train_sampler,
                                              pin_memory=True,
                                              collate_fn=LoadImagesAndLabels.collate_fn)
     return dataloader, dataset
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index e069792e6e3f..a2f69c1a92cb 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -185,10 +185,8 @@ def update(self, model):
         self.updates += 1
         d = self.decay(self.updates)
         with torch.no_grad():
-            if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
-                msd, esd = model.module.state_dict(), self.ema.module.state_dict()
-            else:
-                msd, esd = model.state_dict(), self.ema.state_dict()
+            msd = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
+            esd = self.ema.module.state_dict() if hasattr(self.ema, 'module') else self.ema.state_dict()
 
             for k, v in esd.items():
                 if v.dtype.is_floating_point:
@@ -198,5 +196,6 @@ def update(self, model):
     def update_attr(self, model):
         # Assign attributes (which may change during training)
         for k in model.__dict__.keys():
-            if not k.startswith('_'):
+            if not k.startswith('_') and not isinstance(getattr(model, k),
+                                                        (torch.distributed.ProcessGroupNCCL, torch.distributed.Reducer)):
                 setattr(self.ema, k, getattr(model, k))
diff --git a/utils/utils.py b/utils/utils.py
index 305486a5f6a3..eb66a60dc7e5 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -8,6 +8,7 @@
 from copy import copy
 from pathlib import Path
 from sys import platform
+from contextlib import contextmanager
 
 import cv2
 import matplotlib
@@ -31,6 +32,19 @@
 cv2.setNumThreads(0)
 
 
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+    """
+    Context manager that makes all processes in distributed training wait for the local master to do something first.
+    """
+    if local_rank not in [-1, 0]:
+        torch.distributed.barrier()
+    yield
+    if local_rank == 0:
+        torch.distributed.barrier()
+
+
+
 def init_seeds(seed=0):
     random.seed(seed)
     np.random.seed(seed)
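
Usage sketch (assuming the stock torch.distributed.launch helper, which spawns one process per GPU, passes --local_rank to each process, and exports the env:// rendezvous variables consumed by init_process_group above; the data/cfg flags below are just the parser defaults), e.g. for 2 GPUs:

    python -m torch.distributed.launch --nproc_per_node 2 train.py --batch-size 64 --data data/coco128.yaml --cfg models/yolov5s.yaml

With this patch --batch-size is the total across all GPUs and is divided by the world size in __main__, while running train.py without the launcher leaves --local_rank at its default of -1, so single-GPU and DataParallel training behave as before.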