DDP torch.jit.trace() --sync-bn fix (#4615)
* Remove assert

* debug0

* trace=not opt.sync

* sync to sync_bn fix

* Cleanup
glenn-jocher committed Aug 30, 2021
1 parent bb5ebc2 commit 50a9828
Showing 2 changed files with 6 additions and 6 deletions.
train.py (3 changes: 1 addition & 2 deletions)
@@ -333,7 +333,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                     f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
-                callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
+                callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn)
             # end batch ------------------------------------------------------------------------------------------------

         # Scheduler
@@ -499,7 +499,6 @@ def main(opt):
         assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
         assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
         assert not opt.evolve, '--evolve argument is not compatible with DDP training'
-        assert not opt.sync_bn, '--sync-bn known training issue, see https://github.com/ultralytics/yolov5/issues/3998'
         torch.cuda.set_device(LOCAL_RANK)
         device = torch.device('cuda', LOCAL_RANK)
         dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
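With the assert removed, --sync-bn can now be combined with DDP training, and the flag is forwarded to the batch-end callback so the logger knows whether graph tracing is safe. For context, a minimal sketch of how such a flag is typically wired through a DDP-style script, assuming a setup broadly similar to YOLOv5's; the helper names build_model and run_logging_callback are hypothetical and not part of the repository.

import argparse
import torch
import torch.distributed as dist
import torch.nn as nn

def build_model(opt, device):
    # Toy model standing in for the YOLOv5 model (hypothetical example)
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).to(device)
    if opt.sync_bn and dist.is_available() and dist.is_initialized():
        # Replace BatchNorm layers with SyncBatchNorm for multi-GPU DDP training
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model

def run_logging_callback(ni, model, imgs, sync_bn):
    # The flag travels with the callback so graph tracing can be skipped when SyncBatchNorm is active
    print(f'batch {ni}: graph logging enabled = {not sync_bn}')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm (DDP only)')
    opt = parser.parse_args()
    model = build_model(opt, device='cpu')
    run_logging_callback(0, model, torch.zeros(1, 3, 64, 64), opt.sync_bn)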
utils/loggers/__init__.py (9 changes: 5 additions & 4 deletions)
@@ -69,13 +69,14 @@ def on_pretrain_routine_end(self):
         if self.wandb:
             self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})

-    def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
+    def on_train_batch_end(self, ni, model, imgs, targets, paths, plots, sync_bn):
         # Callback runs on train batch end
         if plots:
             if ni == 0:
-                with warnings.catch_warnings():
-                    warnings.simplefilter('ignore')  # suppress jit trace warning
-                    self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+                if not sync_bn:  # tb.add_graph() --sync known issue https://github.com/ultralytics/yolov5/issues/3754
+                    with warnings.catch_warnings():
+                        warnings.simplefilter('ignore')  # suppress jit trace warning
+                        self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
             if ni < 3:
                 f = self.save_dir / f'train_batch{ni}.jpg'  # filename
                 Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
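The loggers change guards the TensorBoard graph call: when SyncBatchNorm is in use, torch.jit.trace() of the model is skipped entirely instead of asserting earlier in main(). Below is a standalone, minimal sketch of the same guard, assuming a plain nn.Module rather than the YOLOv5 model; the helper name log_tensorboard_graph is hypothetical.

import warnings
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

def log_tensorboard_graph(tb, model, imgs, sync_bn):
    # Hypothetical helper mirroring the guarded call added in this commit
    if sync_bn:
        return  # tracing SyncBatchNorm models is a known issue (yolov5 #3754), skip graph logging
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')  # suppress jit trace warnings
        tb.add_graph(torch.jit.trace(model, imgs[0:1], strict=False), [])

if __name__ == '__main__':
    tb = SummaryWriter('runs/sketch')
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    imgs = torch.zeros(2, 3, 64, 64)
    log_tensorboard_graph(tb, model, imgs, sync_bn=False)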
