From cd55b445c4dcd8003ff4b0b46b64adf7c16e5ce7 Mon Sep 17 00:00:00 2001
From: "yizhi.chen"
Date: Wed, 15 Jul 2020 16:42:33 +0800
Subject: [PATCH] fix the DDP performance deterioration bug.

---
 train.py          | 30 +++++++++++++++---------------
 utils/datasets.py |  2 +-
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/train.py b/train.py
index e423a2b09c41..be3485ad02fb 100644
--- a/train.py
+++ b/train.py
@@ -70,7 +70,7 @@ def train(hyp, tb_writer, opt, device):
     # Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
 
     # Configure
-    init_seeds(1)
+    init_seeds(2+local_rank)
     with open(opt.data) as f:
         data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
     train_path = data_dict['train']
@@ -208,18 +208,20 @@ def train(hyp, tb_writer, opt, device):
     model.names = names
 
     # Class frequency
-    labels = np.concatenate(dataset.labels, 0)
-    c = torch.tensor(labels[:, 0])  # classes
-    # cf = torch.bincount(c.long(), minlength=nc) + 1.
-    # model._initialize_biases(cf.to(device))
-    plot_labels(labels, save_dir=log_dir)
-    if tb_writer:
-        tb_writer.add_hparams(hyp, {})
-        tb_writer.add_histogram('classes', c, 0)
-
-    # Check anchors
-    if not opt.noautoanchor:
-        check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
+    # Only one check and log is needed.
+    if local_rank in [-1, 0]:
+        labels = np.concatenate(dataset.labels, 0)
+        c = torch.tensor(labels[:, 0])  # classes
+        # cf = torch.bincount(c.long(), minlength=nc) + 1.
+        # model._initialize_biases(cf.to(device))
+        plot_labels(labels, save_dir=log_dir)
+        if tb_writer:
+            tb_writer.add_hparams(hyp, {})
+            tb_writer.add_histogram('classes', c, 0)
+
+        # Check anchors
+        if not opt.noautoanchor:
+            check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
     # Start training
     t0 = time.time()
@@ -460,8 +462,6 @@ def train(hyp, tb_writer, opt, device):
         dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
 
         opt.world_size = dist.get_world_size()
-        assert opt.world_size <= 2, \
-            "DDP mode with > 2 gpus will suffer from performance deterioration. The reason remains unknown!"
        assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
         opt.batch_size = opt.total_batch_size // opt.world_size
     print(opt)
diff --git a/utils/datasets.py b/utils/datasets.py
index 2da3940c3c95..a10b647f5839 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -305,7 +305,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
                     f += glob.iglob(p + os.sep + '*.*')
                 else:
                     raise Exception('%s does not exist' % p)
-            self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
+            self.img_files = sorted([x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
         except Exception as e:
             raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
 
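
Note (not part of the patch): the change boils down to three pieces of DDP hygiene: a distinct RNG seed per process, side effects (plots, TensorBoard writes, anchor check) only on the main process, and a deterministic, identical file list on every rank. The standalone sketch below only illustrates that pattern; it assumes a torchrun-style LOCAL_RANK environment variable rather than the repo's --local_rank argument, and the glob path is a placeholder.

import glob
import os
import random

import numpy as np
import torch
import torch.distributed as dist


def init_seeds(seed=0):
    # Seed every RNG; each DDP process should pass a different value
    # (e.g. 2 + local_rank) so data augmentation is not identical on every GPU.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def main():
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank != -1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")

    init_seeds(2 + local_rank)  # distinct seed per rank, as in the patch

    # Every rank must build the identical, ordered list; otherwise the
    # sampler indices refer to different images on different GPUs.
    img_files = sorted(glob.glob("images/train/*.jpg"))  # placeholder path

    # Plots, TensorBoard writes and similar side effects happen only once.
    if local_rank in [-1, 0]:
        world_size = dist.get_world_size() if local_rank != -1 else 1
        print("%d images, world_size %d" % (len(img_files), world_size))


if __name__ == "__main__":
    main()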