Skip to content

Commit

Permalink
Threaded TensorBoard graph logging (ultralytics#9070)
Browse files Browse the repository at this point in the history
* Log TensorBoard graph on pretrain_routine_end

* fix
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 0477dc8 commit 043236d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
model.half().float() # pre-reduce anchor precision

callbacks.run('on_pretrain_routine_end', labels, names, plots)
callbacks.run('on_pretrain_routine_end', labels, names)

# DDP mode
if cuda and RANK != -1:
Expand Down Expand Up @@ -328,7 +328,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
(f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots)
callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
if callbacks.stop_training:
return
# end batch ------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -420,7 +420,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
if is_coco:
callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)

callbacks.run('on_train_end', last, best, plots, epoch, results)
callbacks.run('on_train_end', last, best, epoch, results)

torch.cuda.empty_cache()
return results
Expand Down
34 changes: 18 additions & 16 deletions utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None,
self.weights = weights
self.opt = opt
self.hyp = hyp
self.plots = not opt.noplots # plot results
self.logger = logger # for printing results to console
self.include = include
self.keys = [
Expand Down Expand Up @@ -110,26 +111,26 @@ def on_train_start(self):
# Callback runs on train start
pass

def on_pretrain_routine_end(self, labels, names, plots):
def on_pretrain_routine_end(self, labels, names):
# Callback runs on pre-train routine end
if plots:
if self.plots:
plot_labels(labels, names, self.save_dir)
paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
# if self.clearml:
# pass # ClearML saves these images automatically using hooks
paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
# if self.clearml:
# pass # ClearML saves these images automatically using hooks

def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
def on_train_batch_end(self, model, ni, imgs, targets, paths):
# Callback runs on train batch end
# ni: number integrated batches (since train start)
if plots:
if ni == 0 and not self.opt.sync_bn and self.tb:
log_tensorboard_graph(self.tb, model, imgsz=list(imgs.shape[2:4]))
if self.plots:
if ni < 3:
f = self.save_dir / f'train_batch{ni}.jpg' # filename
plot_images(imgs, targets, paths, f)
if (self.wandb or self.clearml) and ni == 10:
if ni == 0 and self.tb and not self.opt.sync_bn:
log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
if ni == 10 and (self.wandb or self.clearml):
files = sorted(self.save_dir.glob('train*.jpg'))
if self.wandb:
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
Expand Down Expand Up @@ -197,9 +198,9 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
model_name='Latest Model',
auto_delete_file=False)

def on_train_end(self, last, best, plots, epoch, results):
def on_train_end(self, last, best, epoch, results):
# Callback runs on training end, i.e. saving best model
if plots:
if self.plots:
plot_results(file=self.save_dir / 'results.csv') # save results.png
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
Expand Down Expand Up @@ -291,6 +292,7 @@ def log_model(self, model_path, epoch=0, metadata={}):
wandb.log_artifact(art)


@threaded
def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
# Log model graph to TensorBoard
try:
Expand All @@ -300,5 +302,5 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
except Exception:
print('WARNING: TensorBoard graph visualization failure')
except Exception as e:
print(f'WARNING: TensorBoard graph visualization failure {e}')

0 comments on commit 043236d

Please sign in to comment.