diff --git a/train.py b/train.py index ce211f1f5322..9b5deee5d5aa 100644 --- a/train.py +++ b/train.py @@ -194,6 +194,7 @@ def train(hyp, tb_writer, opt, device): # DDP mode if device.type != 'cpu' and opt.local_rank != -1: # pip install torch==1.4.0+cku100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if mixed_precision: model = DDP(model, delay_allreduce=True) else: