diff --git a/train.py b/train.py index ba69de43d255..2eb0a76ef137 100644 --- a/train.py +++ b/train.py @@ -294,7 +294,7 @@ def train(hyp, tb_writer, opt, device): loss, loss_items = compute_loss(pred, targets.to(device), model) # loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices. if local_rank != -1: - loss *= dist.get_world_size() + loss *= opt.world_size if not torch.isfinite(loss): print('WARNING: non-finite loss, ending training ', loss_items) return results @@ -451,15 +451,17 @@ def train(hyp, tb_writer, opt, device): opt.total_batch_size = opt.batch_size if device.type == 'cpu': mixed_precision = False + opt.world_size = 1 elif opt.local_rank != -1: # DDP mode assert torch.cuda.device_count() > opt.local_rank torch.cuda.set_device(opt.local_rank) device = torch.device("cuda", opt.local_rank) dist.init_process_group(backend='nccl', init_method='env://') # distributed backend - - assert opt.batch_size % dist.get_world_size() == 0 - opt.batch_size = opt.total_batch_size // dist.get_world_size() + + opt.world_size = dist.get_world_size() + assert opt.batch_size % opt.world_size == 0 + opt.batch_size = opt.total_batch_size // opt.world_size print(opt) # Train