From d738487089e41c22b3b1cd73aa7c1c40320a6ebf Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 14 Jul 2020 17:33:38 +0700 Subject: [PATCH 1/4] Adding world_size Reduce calls to torch.distributed. For use in create_dataloader. --- train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index ba69de43d255..2eb0a76ef137 100644 --- a/train.py +++ b/train.py @@ -294,7 +294,7 @@ def train(hyp, tb_writer, opt, device): loss, loss_items = compute_loss(pred, targets.to(device), model) # loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices. if local_rank != -1: - loss *= dist.get_world_size() + loss *= opt.world_size if not torch.isfinite(loss): print('WARNING: non-finite loss, ending training ', loss_items) return results @@ -451,15 +451,17 @@ def train(hyp, tb_writer, opt, device): opt.total_batch_size = opt.batch_size if device.type == 'cpu': mixed_precision = False + opt.world_size = 1 elif opt.local_rank != -1: # DDP mode assert torch.cuda.device_count() > opt.local_rank torch.cuda.set_device(opt.local_rank) device = torch.device("cuda", opt.local_rank) dist.init_process_group(backend='nccl', init_method='env://') # distributed backend - - assert opt.batch_size % dist.get_world_size() == 0 - opt.batch_size = opt.total_batch_size // dist.get_world_size() + + opt.world_size = dist.get_world_size() + assert opt.batch_size % opt.world_size == 0 + opt.batch_size = opt.total_batch_size // opt.world_size print(opt) # Train From 69364d6050e048d0d8834e0f30ce84da3f6a13f3 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 14 Jul 2020 17:36:48 +0700 Subject: [PATCH 2/4] Changed number of workers check --- utils/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/datasets.py b/utils/datasets.py index a3a5531f8f54..3d724e6fae62 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -59,7 +59,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa pad=pad) batch_size = min(batch_size, len(dataset)) - nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers + nw = min([os.cpu_count()//opt.world_size, batch_size if batch_size > 1 else 0, 8]) # number of workers train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, From 63648925288d63a21174a4dd28f92dbfebfeb75a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 14 Jul 2020 19:16:15 +0700 Subject: [PATCH 3/4] Add assert message for clarification Clarify why assertion was thrown to users --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 2eb0a76ef137..7494af2754b5 100644 --- a/train.py +++ b/train.py @@ -460,7 +460,7 @@ def train(hyp, tb_writer, opt, device): dist.init_process_group(backend='nccl', init_method='env://') # distributed backend opt.world_size = dist.get_world_size() - assert opt.batch_size % opt.world_size == 0 + assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!" opt.batch_size = opt.total_batch_size // opt.world_size print(opt) From 787582f97251834f955ef05a77072b8c673a8397 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 14 Jul 2020 20:38:58 +0700 Subject: [PATCH 4/4] Fixed issue with single gpu not having world_size --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 7494af2754b5..18c8b531e6f0 100644 --- a/train.py +++ b/train.py @@ -449,9 +449,9 @@ def train(hyp, tb_writer, opt, device): opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size) opt.total_batch_size = opt.batch_size + opt.world_size = 1 if device.type == 'cpu': mixed_precision = False - opt.world_size = 1 elif opt.local_rank != -1: # DDP mode assert torch.cuda.device_count() > opt.local_rank