diff --git a/train.py b/train.py
index ecac59857ccc..6bd65f063391 100644
--- a/train.py
+++ b/train.py
@@ -181,10 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')
 
-    # DDP mode
-    if cuda and rank != -1:
-        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
-
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
@@ -214,7 +210,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             # Anchors
             if not opt.noautoanchor:
                 check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
-            model.half().float()  # pre-reduce anchor precision
+            model.half().float()  # pre-reduce anchor precision
+
+    # DDP mode
+    if cuda and rank != -1:
+        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
 
     # Model parameters
     hyp['box'] *= 3. / nl  # scale to layers
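
For reference, a minimal, self-contained sketch of the ordering this change enforces: finish model edits (the analogue of the autoanchor check and model.half().float() in train.py) before constructing DistributedDataParallel, and skip the wrapper entirely in single-process runs. The wrap_model helper and its arguments are hypothetical and only illustrate the pattern; they are not part of train.py.

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP


    def wrap_model(model: torch.nn.Module, local_rank: int = -1) -> torch.nn.Module:
        """Hypothetical helper: apply pre-DDP model edits, then wrap for multi-GPU training."""
        # 1) Modify the bare model first (stand-in for check_anchors in train.py).
        model.half().float()  # pre-reduce parameter precision, as in train.py

        # 2) Only afterwards wrap in DDP; its constructor broadcasts the module's
        #    parameters and buffers from rank 0, so every rank starts from the
        #    state that includes the rank-0-only edits above.
        if dist.is_available() and dist.is_initialized() and local_rank != -1:
            model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        return model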