Skip to content

Commit

Permalink
Fix the DDP performance-deterioration bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhi.chen committed Jul 15, 2020
1 parent f5921ba commit cd55b44
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
30 changes: 15 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def train(hyp, tb_writer, opt, device):
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.

# Configure
init_seeds(1)
init_seeds(2+local_rank)
with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
train_path = data_dict['train']
Expand Down Expand Up @@ -208,18 +208,20 @@ def train(hyp, tb_writer, opt, device):
model.names = names

# Class frequency
labels = np.concatenate(dataset.labels, 0)
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
# model._initialize_biases(cf.to(device))
plot_labels(labels, save_dir=log_dir)
if tb_writer:
tb_writer.add_hparams(hyp, {})
tb_writer.add_histogram('classes', c, 0)

# Check anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
# Only one check and log is needed.
if local_rank in [-1, 0]:
labels = np.concatenate(dataset.labels, 0)
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
# model._initialize_biases(cf.to(device))
plot_labels(labels, save_dir=log_dir)
if tb_writer:
tb_writer.add_hparams(hyp, {})
tb_writer.add_histogram('classes', c, 0)

# Check anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

# Start training
t0 = time.time()
Expand Down Expand Up @@ -460,8 +462,6 @@ def train(hyp, tb_writer, opt, device):
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend

opt.world_size = dist.get_world_size()
assert opt.world_size <= 2, \
"DDP mode with > 2 gpus will suffer from performance deterioration. The reason remains unknown!"
assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
opt.batch_size = opt.total_batch_size // opt.world_size
print(opt)
Expand Down
2 changes: 1 addition & 1 deletion utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
f += glob.iglob(p + os.sep + '*.*')
else:
raise Exception('%s does not exist' % p)
self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
self.img_files = sorted([x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
except Exception as e:
raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

Expand Down

0 comments on commit cd55b44

Please sign in to comment.