diff --git a/train.py b/train.py
index ce211f1f5322..f05b38d30937 100644
--- a/train.py
+++ b/train.py
@@ -7,6 +7,7 @@
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
@@ -17,9 +18,7 @@
 mixed_precision = True
 try:  # Mixed precision training https://github.com/NVIDIA/apex
     from apex import amp
-    from apex.parallel import DistributedDataParallel as DDP
 except:
-    from torch.nn.parallel import DistributedDataParallel as DDP
     print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
     mixed_precision = False  # not installed
 
@@ -170,6 +169,22 @@ def train(hyp, tb_writer, opt, device):
     # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
     # plot_lr_scheduler(optimizer, scheduler, epochs)
 
+    # DP mode
+    if device.type != 'cpu' and opt.local_rank == -1 and torch.cuda.device_count() > 1:
+        model = torch.nn.DataParallel(model)
+
+    # Exponential moving average
+    # From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
+    # "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
+    # chenyzsjtu: ema should be placed after SyncBN, as SyncBN introduces new modules
+    if device.type != 'cpu' and opt.local_rank != -1:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
+    ema = torch_utils.ModelEMA(model) if opt.local_rank in [-1, 0] else None
+
+    # DDP mode
+    if device.type != 'cpu' and opt.local_rank != -1:
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
                                             cache=opt.cache_images, rect=opt.rect, local_rank=opt.local_rank)
@@ -182,23 +197,6 @@ def train(hyp, tb_writer, opt, device):
     testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
                                    cache=opt.cache_images, rect=True, local_rank=-1)[0]
 
-    # DP mode
-    if device.type != 'cpu' and opt.local_rank == -1 and torch.cuda.device_count() > 1:
-        model = torch.nn.DataParallel(model)
-
-    # Exponential moving average
-    # According to https://github.com/rwightman/pytorch-image-models/blob/master/train.py,
-    # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-    ema = torch_utils.ModelEMA(model) if opt.local_rank in [-1, 0] else None
-
-    # DDP mode
-    if device.type != 'cpu' and opt.local_rank != -1:
-        # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
-        if mixed_precision:
-            model = DDP(model, delay_allreduce=True)
-        else:
-            model = DDP(model, device_ids=[opt.local_rank])
-
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
     model.nc = nc  # attach number of classes to model
@@ -208,7 +206,8 @@ def train(hyp, tb_writer, opt, device):
     model.names = data_dict['names']
 
     # Class frequency
-    if tb_writer:
+    # TODO:
+    if 0:  # tb_writer:
         labels = np.concatenate(dataset.labels, 0)
         c = torch.tensor(labels[:, 0])  # classes
         # cf = torch.bincount(c.long(), minlength=nc) + 1.
@@ -251,7 +250,7 @@ def train(hyp, tb_writer, opt, device):
                 dist.broadcast(indices, 0)
                 if local_rank != 0:
                     dataset.indices = indices.cpu().numpy()
-
+
         # Update mosaic border
         # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
@@ -264,21 +263,21 @@ def train(hyp, tb_writer, opt, device):
             pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
         else:
             pbar = enumerate(dataloader)
+        optimizer.zero_grad()
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             ni = i + nb * epoch  # number integrated batches (since train start)
             imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
 
             # Burn-in
             if ni <= n_burn:
-                ni_burned = ni
                 xi = [0, n_burn]  # x interp
-                # model.gr = np.interp(ni_burned, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
-                accumulate = max(1, np.interp(ni_burned, xi, [1, nbs / total_batch_size]).round())
+                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
+                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                    x['lr'] = np.interp(ni_burned, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
+                    x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                     if 'momentum' in x:
-                        x['momentum'] = np.interp(ni_burned, xi, [0.9, hyp['momentum']])
+                        x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
 
             # Multi-scale
             if opt.multi_scale:
@@ -310,10 +309,9 @@ def train(hyp, tb_writer, opt, device):
             # Optimize
             if ni % accumulate == 0:
                 optimizer.step()
-                torch.cuda.synchronize()
+                optimizer.zero_grad()
                 if ema is not None:
                     ema.update(model)
-                optimizer.zero_grad()
 
             # Print
             if opt.local_rank in [-1, 0]:
@@ -443,7 +441,7 @@ def train(hyp, tb_writer, opt, device):
         # DDP mode
         assert torch.cuda.device_count() > opt.local_rank
         torch.cuda.set_device(opt.local_rank)
-        device = torch.device("cuda")
+        device = torch.device("cuda", opt.local_rank)
         dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
 
         assert opt.batch_size % dist.get_world_size() == 0
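The `@@ -170,6 +169,22 @@` hunk is the core of the patch: wrapper setup moves ahead of the dataloaders and gains a fixed order, with DataParallel or SyncBatchNorm conversion first, then EMA creation, then DDP last. The order matters because `convert_sync_batchnorm` replaces every BatchNorm module in the tree, so an EMA created earlier would track modules that no longer exist. A minimal sketch of that ordering, assuming a generic `nn.Module` and an already-initialized process group (`wrap_model` is an illustrative name, not a `train.py` function):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model(model: nn.Module, device: torch.device, local_rank: int) -> nn.Module:
    # DP mode: one process driving several GPUs
    if device.type != 'cpu' and local_rank == -1 and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # SyncBN must precede EMA creation: conversion swaps in new modules
    if device.type != 'cpu' and local_rank != -1:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)

    # ... create the EMA snapshot here, on ranks [-1, 0] only ...

    # DDP mode: wrap last so gradient sync covers the final module tree
    if device.type != 'cpu' and local_rank != -1:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    return model
```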
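`torch_utils.ModelEMA` itself is unchanged by the patch; for reference, what it implements is a decayed copy of the weights that is updated after every optimizer step. A condensed sketch of that mechanism (illustrative only; the real `ModelEMA` also ramps the decay over early updates and handles DP/DDP-wrapped models):

```python
import copy

import torch
import torch.nn as nn


class SimpleEMA:
    """Decayed moving average of a model's weights (illustrative sketch)."""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.ema = copy.deepcopy(model).eval()  # snapshot of the current module tree
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module):
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # skip integer buffers like num_batches_tracked
                v.mul_(self.decay).add_(msd[k].detach(), alpha=1.0 - self.decay)
```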
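In the burn-in hunk, `ni_burned` was a pure alias of `ni`, so the patch deletes it and interpolates on `ni` directly; behavior is unchanged. A standalone illustration of how `np.interp` ramps `accumulate`, the bias-group learning rate, and momentum across `[0, n_burn]` (all numbers below are made-up stand-ins for `n_burn`, `nbs`, `lr0 * lf(epoch)`, and `hyp['momentum']`):

```python
import numpy as np

n_burn, nbs, total_batch_size = 1000, 64, 16  # assumed example values
xi = [0, n_burn]  # x interp, as in the patch
for ni in (0, 250, 1000):  # integrated batch counter
    accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
    bias_lr = np.interp(ni, xi, [0.1, 0.01])    # bias group falls from 0.1 toward lr0 * lf(epoch)
    momentum = np.interp(ni, xi, [0.9, 0.937])  # rises from 0.9 toward hyp['momentum']
    print(f'ni={ni}: accumulate={accumulate:.0f} bias_lr={bias_lr:.4f} momentum={momentum:.4f}')
```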
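The `@@ -310,10 +309,9 @@` hunk reorders the optimize step: `zero_grad()` now runs once before the batch loop and immediately after each `step()`, and the explicit `torch.cuda.synchronize()` is dropped, since CUDA work queued on a stream already executes in order and `step()` needs no manual barrier. A toy version of the resulting accumulation pattern, with a dummy linear model standing in for the YOLO model:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulate = 4  # sum gradients over several batches per weight update

optimizer.zero_grad()  # start the first accumulation window clean
for ni in range(8):
    loss = model(torch.randn(2, 4)).mean()
    loss.backward()  # gradients add up across iterations until step()
    if ni % accumulate == 0:
        optimizer.step()
        optimizer.zero_grad()  # reset right after the step, as in the patch
        # ema.update(model) would run here, after the weights have moved
```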
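Finally, `torch.device("cuda", opt.local_rank)` binds the device object to this process's GPU index; a bare `"cuda"` means "current device", which is correct only as long as every tensor is created after `set_device`. A minimal per-process setup sketch, assuming launch via `python -m torch.distributed.launch --nproc_per_node=N train.py` so that `--local_rank` and the `env://` rendezvous variables are provided:

```python
import argparse

import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)  # set by torch.distributed.launch
opt = parser.parse_args()

assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device('cuda', opt.local_rank)  # explicit index, not bare 'cuda'
dist.init_process_group(backend='nccl', init_method='env://')  # reads MASTER_ADDR/PORT, RANK, WORLD_SIZE
print(f'rank {dist.get_rank()}/{dist.get_world_size()} using {device}')
```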