From d738487089e41c22b3b1cd73aa7c1c40320a6ebf Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 14 Jul 2020 17:33:38 +0700
Subject: [PATCH] Adding world_size

Reduce calls to torch.distributed. For use in create_dataloader.
---
 train.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/train.py b/train.py
index ba69de43d255..2eb0a76ef137 100644
--- a/train.py
+++ b/train.py
@@ -294,7 +294,7 @@ def train(hyp, tb_writer, opt, device):
             loss, loss_items = compute_loss(pred, targets.to(device), model)
             # loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
             if local_rank != -1:
-                loss *= dist.get_world_size()
+                loss *= opt.world_size
             if not torch.isfinite(loss):
                 print('WARNING: non-finite loss, ending training ', loss_items)
                 return results
@@ -451,15 +451,17 @@ def train(hyp, tb_writer, opt, device):
     opt.total_batch_size = opt.batch_size
     if device.type == 'cpu':
         mixed_precision = False
+        opt.world_size = 1
     elif opt.local_rank != -1:
         # DDP mode
         assert torch.cuda.device_count() > opt.local_rank
         torch.cuda.set_device(opt.local_rank)
         device = torch.device("cuda", opt.local_rank)
         dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
-
-        assert opt.batch_size % dist.get_world_size() == 0
-        opt.batch_size = opt.total_batch_size // dist.get_world_size()
+
+        opt.world_size = dist.get_world_size()
+        assert opt.batch_size % opt.world_size == 0
+        opt.batch_size = opt.total_batch_size // opt.world_size
     print(opt)

     # Train
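
For context, a minimal sketch of the pattern this patch introduces: query torch.distributed once after init_process_group, cache the result as opt.world_size, and reuse the cached value both for the per-GPU batch size handed to create_dataloader and for the loss rescaling inside the training loop. This is not the repository's code; the create_dataloader below is a stand-in with an assumed signature, and the argument names are illustrative.

import argparse

import torch
import torch.distributed as dist


def create_dataloader(path, batch_size, local_rank=-1, world_size=1):
    # Stand-in with an assumed signature: a real implementation would build the
    # dataset and, when world_size > 1, wrap it in a
    # torch.utils.data.distributed.DistributedSampler so each process loads
    # 1/world_size of the data at its per-GPU batch_size.
    print('loader: path=%s batch_size=%d rank=%d world_size=%d'
          % (path, batch_size, local_rank, world_size))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size across all GPUs')
    parser.add_argument('--local_rank', type=int, default=-1, help='set by torch.distributed.launch')
    opt = parser.parse_args()

    opt.total_batch_size = opt.batch_size
    device = torch.device('cpu')

    if opt.local_rank == -1 or not torch.cuda.is_available():
        opt.world_size = 1  # single-process run: keep a well-defined value
    else:
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        opt.world_size = dist.get_world_size()                    # queried once, cached on opt
        assert opt.total_batch_size % opt.world_size == 0
        opt.batch_size = opt.total_batch_size // opt.world_size   # per-GPU share of the total batch

    create_dataloader('data/train', opt.batch_size,
                      local_rank=opt.local_rank, world_size=opt.world_size)

    # Inside the training loop the cached value replaces dist.get_world_size():
    #     loss *= opt.world_size  # undo DDP's gradient averaging across devices

Caching the value once avoids repeated lookups of distributed state and gives the CPU / single-process path a well-defined world_size of 1, so downstream code does not need to special-case non-DDP runs.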