diff --git a/train.py b/train.py
index ce211f1f5322..27c83cd7e56a 100644
--- a/train.py
+++ b/train.py
@@ -7,6 +7,7 @@ import torch.optim.lr_scheduler as lr_scheduler
 import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
@@ -17,9 +18,7 @@ mixed_precision = True
 try:  # Mixed precision training https://github.com/NVIDIA/apex
     from apex import amp
-    from apex.parallel import DistributedDataParallel as DDP
 except:
-    from torch.nn.parallel import DistributedDataParallel as DDP
     print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
     mixed_precision = False  # not installed
@@ -194,10 +193,8 @@ def train(hyp, tb_writer, opt, device):
     # DDP mode
     if device.type != 'cpu' and opt.local_rank != -1:
         # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
-        if mixed_precision:
-            model = DDP(model, delay_allreduce=True)
-        else:
-            model = DDP(model, device_ids=[opt.local_rank])
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
 
     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
@@ -402,6 +399,7 @@ def train(hyp, tb_writer, opt, device):
     if not opt.evolve:
         plot_results()  # save as results.png
         print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
 
+    dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
     torch.cuda.empty_cache()
     return results
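In short, the change drops the apex DistributedDataParallel fallback in favour of PyTorch's native DDP, converts BatchNorm layers to SyncBatchNorm before wrapping the model, and tears down the process group once training finishes. Below is a minimal standalone sketch of that setup pattern, assuming a single-node launch with `python -m torch.distributed.launch --nproc_per_node=<gpus> script.py`; the dummy model and the argparse plumbing are illustrative placeholders, not code from this PR.

import argparse

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # filled in by torch.distributed.launch
opt = parser.parse_args()

dist.init_process_group(backend='nccl', init_method='env://')  # launcher sets RANK/WORLD_SIZE
torch.cuda.set_device(opt.local_rank)

# Placeholder model standing in for the YOLO model built in train()
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).cuda()

# BatchNorm layers become SyncBatchNorm so running statistics are reduced across ranks,
# then the model is wrapped in torch-native DDP bound to this process's GPU.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)

# ... training loop would run here ...

dist.destroy_process_group()  # mirrors the cleanup added at the end of train()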