diff --git a/train.py b/train.py
index 18c8b531e6f0..f42479094ea5 100644
--- a/train.py
+++ b/train.py
@@ -413,7 +413,7 @@ def train(hyp, tb_writer, opt, device):
     parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
     parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
     parser.add_argument('--epochs', type=int, default=300)
-    parser.add_argument('--batch-size', type=int, default=16, help="batch size for all gpus.")
+    parser.add_argument('--batch-size', type=int, default=16, help="Total batch size for all gpus.")
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', nargs='?', const='get_last', default=False,
@@ -460,6 +460,8 @@ def train(hyp, tb_writer, opt, device):
         dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
         opt.world_size = dist.get_world_size()
+        assert opt.world_size <= 2, \
+            "DDP mode with > 2 gpus will suffer from performance deterioration. The reason remains unknown!"
         assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
         opt.batch_size = opt.total_batch_size // opt.world_size
     print(opt)
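
Note on the per-GPU batch size (a minimal sketch, not part of the patch): the added asserts cap DDP at 2 GPUs and require the total --batch-size to divide evenly over opt.world_size, after which each process trains with opt.total_batch_size // opt.world_size. The helper below is hypothetical and only mirrors that arithmetic, reusing the same assert messages as the patch.

def per_gpu_batch_size(total_batch_size: int, world_size: int) -> int:
    # Hypothetical helper mirroring the checks added in the diff above.
    # The patch caps DDP at 2 gpus because >2 showed an unexplained slowdown.
    assert world_size <= 2, \
        "DDP mode with > 2 gpus will suffer from performance deterioration. The reason remains unknown!"
    # The total batch size must split evenly across the devices.
    assert total_batch_size % world_size == 0, \
        "Batch size is not a multiple of the number of devices given!"
    return total_batch_size // world_size

print(per_gpu_batch_size(16, 2))  # -> 8 images per process for --batch-size 16 on 2 gpus

Assuming the standard PyTorch launcher is used (e.g. python -m torch.distributed.launch --nproc_per_node 2 train.py --batch-size 16), each of the 2 processes would therefore see a batch of 8.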