diff --git a/train.py b/train.py
index ba69de43d255..18c8b531e6f0 100644
--- a/train.py
+++ b/train.py
@@ -294,7 +294,7 @@ def train(hyp, tb_writer, opt, device):
             loss, loss_items = compute_loss(pred, targets.to(device), model)
             # loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
             if local_rank != -1:
-                loss *= dist.get_world_size()
+                loss *= opt.world_size
             if not torch.isfinite(loss):
                 print('WARNING: non-finite loss, ending training ', loss_items)
                 return results
@@ -449,6 +449,7 @@ def train(hyp, tb_writer, opt, device):
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
     opt.total_batch_size = opt.batch_size
+    opt.world_size = 1
     if device.type == 'cpu':
         mixed_precision = False
     elif opt.local_rank != -1:
@@ -457,9 +458,10 @@ def train(hyp, tb_writer, opt, device):
         torch.cuda.set_device(opt.local_rank)
         device = torch.device("cuda", opt.local_rank)
         dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
-
-        assert opt.batch_size % dist.get_world_size() == 0
-        opt.batch_size = opt.total_batch_size // dist.get_world_size()
+
+        opt.world_size = dist.get_world_size()
+        assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
+        opt.batch_size = opt.total_batch_size // opt.world_size
     print(opt)
 
     # Train
diff --git a/utils/datasets.py b/utils/datasets.py
index a3a5531f8f54..3d724e6fae62 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -59,7 +59,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                   pad=pad)
 
     batch_size = min(batch_size, len(dataset))
-    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
+    nw = min([os.cpu_count()//opt.world_size, batch_size if batch_size > 1 else 0, 8])  # number of workers
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None
     dataloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
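
For context on the first hunk, here is a minimal arithmetic sketch (not part of the patch; the `world_size` and `per_gpu_batch` values are hypothetical) of why the loss is multiplied by `opt.world_size`: `DistributedDataParallel` averages gradients across ranks during the backward all-reduce, so a loss that `compute_loss` has already scaled by the per-GPU batch size would otherwise produce a gradient corresponding to only `total_batch_size / world_size` samples.

```python
# Toy sketch (assumption, not from the patch): DDP's all-reduce takes the MEAN
# of per-rank gradients. If each rank's loss is a sum over its local batch, the
# averaged gradient equals 1/world_size of the gradient a single GPU would see
# for the full total batch. Scaling the loss by world_size cancels that factor,
# matching single-GPU behaviour at the same total batch size.

world_size = 4           # hypothetical number of GPUs / DDP processes
per_gpu_batch = 16       # opt.batch_size after the division by world_size
total_batch = per_gpu_batch * world_size  # opt.total_batch_size

# Stand-in "gradient magnitudes", proportional to the number of samples summed:
grad_single_gpu = total_batch                               # one GPU, full batch
grad_ddp_mean = (per_gpu_batch * world_size) / world_size   # after DDP mean all-reduce
grad_ddp_scaled = world_size * grad_ddp_mean                # with loss *= opt.world_size

assert grad_ddp_scaled == grad_single_gpu
print(grad_single_gpu, grad_ddp_mean, grad_ddp_scaled)  # 64 16.0 64.0
```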